returnn 1.20230501.124840__tar.gz → 1.20230503.142906__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {returnn-1.20230501.124840/returnn.egg-info → returnn-1.20230503.142906}/PKG-INFO +1 -1
- returnn-1.20230503.142906/_setup_info_generated.py +2 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/_backend.py +14 -1
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/cond.py +7 -3
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/const.py +11 -1
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/parameter.py +6 -1
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/__init__.py +1 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/_backend.py +9 -1
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/_utils.py +19 -0
- returnn-1.20230503.142906/returnn/tf/frontend_layers/cond.py +252 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/layers/basic.py +40 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/util/basic.py +7 -5
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/frontend/_backend.py +8 -3
- {returnn-1.20230501.124840 → returnn-1.20230503.142906/returnn.egg-info}/PKG-INFO +1 -1
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn.egg-info/SOURCES.txt +2 -0
- returnn-1.20230503.142906/tests/test_rf_cond.py +150 -0
- returnn-1.20230501.124840/_setup_info_generated.py +0 -2
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/.editorconfig +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/.gitignore +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/.gitmodules +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/.kateconfig +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/CHANGELOG.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/CODEOWNERS +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/CONTRIBUTING.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/LICENSE +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/MANIFEST.in +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/README.rst +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/12AX.cluster_map +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/_setup_returnn_env.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-fwd.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-horovod-mpi.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-horovod-mpi.py.sh +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-horovod-mpi.sh +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-hyper-param-tuning.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-iter-dataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-list-devices.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-lua-torch-layer.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-pretrain.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-record-and-push-to-webserver.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-returnn-as-framework.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-rf.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-rhn-enwik8.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-sprint-interface.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-att-copy.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-attention.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-enc-dec.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-hard-att-copy.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-lstm-benchmark.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-native-lstm.12ax.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-native-lstm2.12ax.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-neural-transducer.12ax.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-rec-explicit-lstm.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-rec-explicit-rnn.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-rec-self-att.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-search-compiled-graph.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-timit-lstm-ctc.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-torch.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/demo.sh +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/README.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/chars.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/config_demo +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/config_fwd +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/config_real +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/decode.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/go.sh +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/lines.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/split/eval.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/split/train.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/IAM/split/valid.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/README.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/artificial/create_test_h5.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/artificial/forwardconfig +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/artificial/go.sh +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/artificial/trainconfig +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/artificial_rgb/go.sh +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/pyproject.toml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/requirements.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/__main__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/__old_mod_loader__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/__setup__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/config.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/audio.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/basic.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/bundle_file.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/cached.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/cached2.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/generating.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/hdf.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/lm.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/map.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/meta.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/multi_proc.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/normalization_data.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/numpy_dump.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/raw_wav.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/sprint.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/stereo.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/util/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/util/feature_extraction.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/datasets/util/vocabulary.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/engine/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/engine/base.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/engine/batch.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/__main__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/.git +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/graph_editor/README.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/graph_editor/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/graph_editor/edit.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/graph_editor/reroute.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/graph_editor/select.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/graph_editor/subgraph.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/graph_editor/transform.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/extern/graph_editor/util.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/_numpy_backend.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/_utils.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/array_.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/attention.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/container.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/control_flow_ctx.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/conv.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/device.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/dims.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/dropout.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/dtype.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/encoder/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/encoder/base.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/encoder/conformer.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/gradient.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/init.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/linear.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/loop.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/loss.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/math_.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/matmul.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/module.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/normalization.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/rand.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/rec.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/reduce.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/run_ctx.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/signal.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/state.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/frontend/types.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/import_/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/import_/common.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/import_/git.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/import_/import_.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/learning_rate_control.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/log.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/native_op.cpp +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/native_op.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/pretrain.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/sprint/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/sprint/cache.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/sprint/control.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/sprint/error_signals.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/sprint/extern_interface.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/sprint/interface.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/README.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/_dim_extra.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/_tensor_extra.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/_tensor_mixin_base.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/_tensor_op_overloads.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/control_flow_ctx.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/dim.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/marked_dim.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/tensor.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tensor/tensor_dict.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/compat.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/data_pipeline.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/distributed.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/engine.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/dims.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/layer.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/make_layer.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_low_level/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_low_level/_backend.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/horovod.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/hyper_param_tuning.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/layers/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/layers/base.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/layers/rec.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/layers/segmental_model.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/layers/signal_processing.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/native_op.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/network.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/sprint.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/updater.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/util/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/util/data.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/util/ken_lm.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/util/open_fst.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/README.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/data/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/data/pipeline.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/data/tensor_utils.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/engine.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/frontend/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/frontend/_rand.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/frontend/bridge.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/functional/README.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/functional/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/torch/updater.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/__init__.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/basic.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/better_exchook.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/bpe.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/debug.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/debug_helpers.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/fsa.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/literal_py_to_pickle.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/pprint.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/py-to-pickle.cpp +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/py_compat.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/sig_proc.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/util/task_system.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn.egg-info/dependency_links.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn.egg-info/top_level.txt +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/rnn.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/setup.cfg +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/setup.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/DummySprintExec.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm-inspection-profile.xml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/.gitignore +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/.name +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/misc.xml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/modules.xml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/returnn.iml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/_set_num_threads1.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/_setup_returnn_env.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/_setup_test_env.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/bpe-unicode-demo.codes +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/bpe-unicode-demo.vocab +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/lexicon_opt.fst +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/lexicon_opt.isyms +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/lexicon_opt.jpg +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/lexicon_opt.osyms +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/lint_common.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/pycharm-inspect.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/pylint.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/returnn-as-framework.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/rf_utils.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/spelling.dic +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_Config.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_Dataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_Fsa.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_GeneratingDataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_HDFDataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_LearningRateControl.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_Log.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_MultiProcDataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_PTDataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_Pretrain.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_ResNet.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_SprintDataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_SprintInterface.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TFEngine.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TFNativeOp.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TFNetworkLayer.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TFNetworkRecLayer.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TFNetworkSigProcLayer.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TFUpdater.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TFUtil.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TF_determinism.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TaskSystem.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TaskSystem_SharedMem.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_TranslationDataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_Util.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_demos.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_fork_exec.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_hdf_dump.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_array.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_attention.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_base.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_container.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_conv.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_encoder_conformer.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_math.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_normalization.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_rec.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_rf_signal.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_tensor.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_tools.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_torch_frontend.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tests/test_torch_internal_frontend.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/_setup_returnn_env.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/analyze-dataset-batches.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/bliss-collect-seq-lens.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/bliss-dump-text.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/bliss-get-segment-names.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/bliss-to-ogg-zip.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/bpe-create-lexicon.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/calculate-word-error-rate.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/cleanup-old-models.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/collect-orth-symbols.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/collect-words.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/compile_native_op.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/compile_tf_graph.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/debug-dump-search-scores.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/debug-plot-search-scores.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/dump-dataset-raw-strings.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/dump-dataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/dump-forward-stats.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/dump-forward.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/dump-network-json.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/dump-pickle.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/extract_state_tying_from_dataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/get-attention-weights.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/get-best-model-epoch.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/hdf_dump.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/hdf_dump_translation_dataset.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/import-blocks-mt-model.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/import-t2t-mt-model.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/.gitignore +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/Makefile +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/README.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/example/README.md +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/example/libs_list +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/example/state_vars_list +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/example/tensor_names_list +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/file.h +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/main.cc +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/rescorer.h +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/vocabulary.cc +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/lattice_rescorer/vocabulary.h +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/tf_avg_checkpoints.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/tf_inspect_checkpoint.py +0 -0
- {returnn-1.20230501.124840 → returnn-1.20230503.142906}/tools/tf_inspect_summary_log.py +0 -0
|
@@ -3,7 +3,7 @@ Backends for the frontend API
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
|
-
from typing import TYPE_CHECKING, Optional, Any, Union, TypeVar, Generic, Type, Sequence, Dict, Tuple
|
|
6
|
+
from typing import TYPE_CHECKING, Optional, Any, Union, TypeVar, Generic, Type, Callable, Sequence, Dict, Tuple
|
|
7
7
|
import contextlib
|
|
8
8
|
import numpy
|
|
9
9
|
import returnn.frontend as rf
|
|
@@ -40,6 +40,19 @@ class Backend(Generic[T]):
|
|
|
40
40
|
"""
|
|
41
41
|
raise NotImplementedError
|
|
42
42
|
|
|
43
|
+
@staticmethod
|
|
44
|
+
def cond(pred: Tensor, true_fn: Callable, false_fn: Callable):
|
|
45
|
+
"""
|
|
46
|
+
cond: conditional execution.
|
|
47
|
+
|
|
48
|
+
Note that this does not need an implementation for eager-based frameworks
|
|
49
|
+
(:func:`executing_eagerly` returns True),
|
|
50
|
+
as the :func:`returnn.frontend.cond` function already covers that case.
|
|
51
|
+
"""
|
|
52
|
+
# noinspection PyProtectedMember
|
|
53
|
+
assert not pred._raw_backend.executing_eagerly(), "should not get here"
|
|
54
|
+
raise NotImplementedError
|
|
55
|
+
|
|
43
56
|
@staticmethod
|
|
44
57
|
def set_random_seed(seed: int):
|
|
45
58
|
"""
|
|
@@ -8,11 +8,14 @@ https://github.com/rwth-i6/returnn/issues/1282
|
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
10
|
from __future__ import annotations
|
|
11
|
-
from typing import Union
|
|
11
|
+
from typing import Union, TypeVar, Callable
|
|
12
12
|
from returnn.tensor import Tensor
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
|
|
15
|
+
T = TypeVar("T")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def cond(pred: Union[bool, Tensor], true_fn: Callable[[], T], false_fn: Callable[[], T]) -> T:
|
|
16
19
|
"""
|
|
17
20
|
:param pred:
|
|
18
21
|
:param true_fn:
|
|
@@ -24,6 +27,7 @@ def cond(pred: Union[bool, Tensor], true_fn, false_fn):
|
|
|
24
27
|
return true_fn()
|
|
25
28
|
else:
|
|
26
29
|
return false_fn()
|
|
30
|
+
assert isinstance(pred, Tensor) and pred.dims == () and pred.dtype == "bool"
|
|
27
31
|
# noinspection PyProtectedMember
|
|
28
32
|
backend = pred._raw_backend
|
|
29
33
|
if backend.executing_eagerly():
|
|
@@ -31,4 +35,4 @@ def cond(pred: Union[bool, Tensor], true_fn, false_fn):
|
|
|
31
35
|
return true_fn()
|
|
32
36
|
else:
|
|
33
37
|
return false_fn()
|
|
34
|
-
|
|
38
|
+
return backend.cond(pred, true_fn, false_fn)
|
|
@@ -10,7 +10,7 @@ from ._backend import global_backend
|
|
|
10
10
|
import returnn.frontend as rf
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
__all__ = ["full", "constant", "fill", "zeros", "ones"]
|
|
13
|
+
__all__ = ["full", "constant", "fill", "zeros", "ones", "zeros_like", "ones_like"]
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def full(
|
|
@@ -54,3 +54,13 @@ def ones(dims: Sequence[Dim], *, dtype: Optional[str] = None, sparse_dim: Option
|
|
|
54
54
|
ones. float by default.
|
|
55
55
|
"""
|
|
56
56
|
return full(dims=dims, fill_value=1, dtype=dtype or rf.get_default_float_dtype(), sparse_dim=sparse_dim)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def zeros_like(other: Tensor) -> Tensor:
|
|
60
|
+
"""zeros like other"""
|
|
61
|
+
return zeros(dims=other.dims, dtype=other.dtype, sparse_dim=other.sparse_dim)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def ones_like(other: Tensor) -> Tensor:
|
|
65
|
+
"""ones like other"""
|
|
66
|
+
return ones(dims=other.dims, dtype=other.dtype, sparse_dim=other.sparse_dim)
|
|
@@ -138,7 +138,12 @@ class Parameter(Tensor[T]):
|
|
|
138
138
|
def trainable(self, trainable: Optional[bool]):
|
|
139
139
|
self._trainable = trainable
|
|
140
140
|
if trainable is None:
|
|
141
|
-
|
|
141
|
+
if self.auxiliary:
|
|
142
|
+
trainable = False
|
|
143
|
+
elif self.dtype.startswith("int"):
|
|
144
|
+
trainable = False
|
|
145
|
+
else:
|
|
146
|
+
trainable = True
|
|
142
147
|
self._raw_backend.set_parameter_trainable(self, trainable)
|
|
143
148
|
|
|
144
149
|
@property
|
{returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/_backend.py
RENAMED
|
@@ -3,7 +3,7 @@ High-level backend for RETURNN layers
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
|
-
from typing import Union, Sequence, Optional, Any, Tuple, Dict
|
|
6
|
+
from typing import Union, Sequence, Optional, Any, Callable, Tuple, Dict
|
|
7
7
|
import contextlib
|
|
8
8
|
import numpy
|
|
9
9
|
import tensorflow as tf
|
|
@@ -37,6 +37,14 @@ class ReturnnLayersBackend(Backend[Layer]):
|
|
|
37
37
|
"""executing eagerly"""
|
|
38
38
|
return False
|
|
39
39
|
|
|
40
|
+
@staticmethod
|
|
41
|
+
def cond(pred: Tensor, true_fn: Callable, false_fn: Callable):
|
|
42
|
+
"""cond"""
|
|
43
|
+
with rfl.Cond(pred) as cond:
|
|
44
|
+
cond.true = true_fn()
|
|
45
|
+
cond.false = false_fn()
|
|
46
|
+
return cond.result
|
|
47
|
+
|
|
40
48
|
@staticmethod
|
|
41
49
|
def set_random_seed(seed: int):
|
|
42
50
|
"""
|
{returnn-1.20230501.124840 → returnn-1.20230503.142906}/returnn/tf/frontend_layers/_utils.py
RENAMED
|
@@ -30,6 +30,25 @@ def copy(tensor: Tensor[rfl.Layer], *, name: Union[rfl.Layer, str]) -> Tensor[rf
|
|
|
30
30
|
return rfl.make_layer({"class": "copy", "from": tensor}, name=name)
|
|
31
31
|
|
|
32
32
|
|
|
33
|
+
def mark_as_output_in_scope(tensor: Tensor, scope: rfl.Layer) -> Tensor:
|
|
34
|
+
"""
|
|
35
|
+
Mark this as an output.
|
|
36
|
+
"""
|
|
37
|
+
assert tensor.raw_tensor.layer_dict, f"mark_as_output can only be called on a layer, not a layer-ref {tensor}."
|
|
38
|
+
res = tensor
|
|
39
|
+
if tensor.raw_tensor is scope.children.get("output"):
|
|
40
|
+
pass # not needed
|
|
41
|
+
elif tensor.raw_tensor.parent is not scope:
|
|
42
|
+
res = copy(tensor, name=scope.get_new_child(suggested_name=tensor.raw_tensor.get_abs_name(join_str="_")))
|
|
43
|
+
res.raw_tensor.layer_dict["is_output_layer"] = True
|
|
44
|
+
else:
|
|
45
|
+
assert tensor.raw_tensor.parent is scope
|
|
46
|
+
assert tensor.raw_tensor.layer_dict
|
|
47
|
+
tensor.raw_tensor.layer_dict["is_output_layer"] = True
|
|
48
|
+
scope.marked_outputs.append(res)
|
|
49
|
+
return res
|
|
50
|
+
|
|
51
|
+
|
|
33
52
|
def get_last_hidden_state(
|
|
34
53
|
source: Tensor,
|
|
35
54
|
*,
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Conditional logic
|
|
3
|
+
|
|
4
|
+
https://github.com/rwth-i6/returnn_common/issues/24
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
from typing import List, TypeVar, Generic
|
|
9
|
+
from tensorflow.python.util import nest
|
|
10
|
+
from returnn.tensor import Tensor, ControlFlowContext
|
|
11
|
+
import returnn.frontend as rf
|
|
12
|
+
import returnn.tf.frontend_layers as rfl
|
|
13
|
+
from . import _utils
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
T = TypeVar("T")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
__all__ = ["Cond", "CondModule"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Cond(Generic[T]):
|
|
23
|
+
"""
|
|
24
|
+
Conditional branching. Basically behaves like ``if ... else ...``.
|
|
25
|
+
Only one branch will be executed, and the condition needs to be a bool scalar.
|
|
26
|
+
This wraps to :class:`CondLayer` in RETURNN and to ``tf.cond`` in TensorFlow.
|
|
27
|
+
|
|
28
|
+
Example::
|
|
29
|
+
|
|
30
|
+
with Cond(cond) as cond_obj:
|
|
31
|
+
cond_obj.true = mod_true_case(x)
|
|
32
|
+
cond_obj.false = mod_false_case(x)
|
|
33
|
+
y = cond_obj.result
|
|
34
|
+
|
|
35
|
+
Corresponds to::
|
|
36
|
+
|
|
37
|
+
if cond:
|
|
38
|
+
y = mod_true_case(x)
|
|
39
|
+
else:
|
|
40
|
+
y = mod_false_case(x)
|
|
41
|
+
|
|
42
|
+
The context scope has two states corresponding to the True and False computation branch.
|
|
43
|
+
The initial state is the True branch.
|
|
44
|
+
Assigning ``cond_obj.true`` has the side effect of switching the computation to the False branch.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
def __init__(self, condition: Tensor, *, name: str = "cond"):
|
|
48
|
+
self.condition = condition
|
|
49
|
+
self._entered = False
|
|
50
|
+
self._entered_state = True
|
|
51
|
+
self._true_value = None
|
|
52
|
+
self._false_value = None
|
|
53
|
+
self._result_value = None
|
|
54
|
+
self.layer_module = CondModule(cond=self)
|
|
55
|
+
self.name_ctx = rfl.Layer(
|
|
56
|
+
module=self.layer_module, suggested_name=name, parent=rfl.Layer.current_ctx(), can_access_children=False
|
|
57
|
+
)
|
|
58
|
+
self.name_ctx.custom_layer_name_scope = ""
|
|
59
|
+
self.true_branch_control_flow_ctx = ControlFlowContext(
|
|
60
|
+
kind=ControlFlowContext.Types.Cond, outer_ctx=self.name_ctx.control_flow_ctx()
|
|
61
|
+
)
|
|
62
|
+
self.true_branch_name_ctx = rfl.Layer(
|
|
63
|
+
module=self.layer_module,
|
|
64
|
+
suggested_name="true",
|
|
65
|
+
parent=self.name_ctx,
|
|
66
|
+
virtual=True,
|
|
67
|
+
can_access_children=False,
|
|
68
|
+
new_control_flow_ctx=self.true_branch_control_flow_ctx,
|
|
69
|
+
)
|
|
70
|
+
self.true_branch_name_ctx.is_subnet = True
|
|
71
|
+
self.true_branch_name_ctx.extend_reserved_names({"output"})
|
|
72
|
+
self.false_branch_control_flow_ctx = ControlFlowContext(
|
|
73
|
+
kind=ControlFlowContext.Types.Cond, outer_ctx=self.name_ctx.control_flow_ctx()
|
|
74
|
+
)
|
|
75
|
+
self.false_branch_name_ctx = rfl.Layer(
|
|
76
|
+
module=self.layer_module,
|
|
77
|
+
suggested_name="false",
|
|
78
|
+
parent=self.name_ctx,
|
|
79
|
+
virtual=True,
|
|
80
|
+
can_access_children=False,
|
|
81
|
+
new_control_flow_ctx=self.false_branch_control_flow_ctx,
|
|
82
|
+
)
|
|
83
|
+
self.false_branch_name_ctx.is_subnet = True
|
|
84
|
+
self.false_branch_name_ctx.extend_reserved_names({"output"})
|
|
85
|
+
|
|
86
|
+
def __repr__(self):
|
|
87
|
+
return f"Cond{self.name_ctx}"
|
|
88
|
+
|
|
89
|
+
def __enter__(self):
|
|
90
|
+
assert not self._entered, f"{self} cannot enter twice"
|
|
91
|
+
self._entered = True
|
|
92
|
+
self._entered_state = True
|
|
93
|
+
self.true_branch_name_ctx.__enter__()
|
|
94
|
+
return self
|
|
95
|
+
|
|
96
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
97
|
+
# First exit any scopes and do cleanup without throwing any exceptions.
|
|
98
|
+
if self._entered:
|
|
99
|
+
if self._true_value is None:
|
|
100
|
+
self.true_branch_name_ctx.__exit__(exc_type, exc_val, exc_tb)
|
|
101
|
+
elif self._false_value is None:
|
|
102
|
+
self.false_branch_name_ctx.__exit__(exc_type, exc_val, exc_tb)
|
|
103
|
+
if not exc_type: # only do error checking if there was no other exception
|
|
104
|
+
assert self._entered
|
|
105
|
+
assert self._true_value is not None, f"{self} you need to call else_()"
|
|
106
|
+
assert self._false_value is not None, f"{self} you need to call end()"
|
|
107
|
+
self._entered = False
|
|
108
|
+
|
|
109
|
+
@property
|
|
110
|
+
def true(self) -> T:
|
|
111
|
+
"""
|
|
112
|
+
The getter usually would not be used.
|
|
113
|
+
"""
|
|
114
|
+
return self._true_value
|
|
115
|
+
|
|
116
|
+
@true.setter
|
|
117
|
+
def true(self, true_value: T):
|
|
118
|
+
"""
|
|
119
|
+
Defines the True branch value.
|
|
120
|
+
Enter the False branch.
|
|
121
|
+
Assign self.false afterwards.
|
|
122
|
+
"""
|
|
123
|
+
assert self._entered, f"{self} you need to be in the context scope"
|
|
124
|
+
assert self._entered_state is True, f"{self} you cannot enter the else branch twice"
|
|
125
|
+
assert true_value is not None
|
|
126
|
+
assert self._true_value is None
|
|
127
|
+
if isinstance(true_value, Tensor):
|
|
128
|
+
true_value = _utils.copy(true_value, name=self.true_branch_name_ctx.get_child("output"))
|
|
129
|
+
else:
|
|
130
|
+
values_flat = nest.flatten(true_value) # type: List[Tensor]
|
|
131
|
+
assert values_flat
|
|
132
|
+
for i, v in enumerate(values_flat):
|
|
133
|
+
assert isinstance(v, Tensor), f"unexpected {true_value!r}, only expects tensors, got {type(v)}"
|
|
134
|
+
if i == 0:
|
|
135
|
+
values_flat[i] = _utils.copy(v, name=self.true_branch_name_ctx.get_child("output"))
|
|
136
|
+
else:
|
|
137
|
+
values_flat[i] = _utils.mark_as_output_in_scope(v, scope=self.true_branch_name_ctx)
|
|
138
|
+
true_value = nest.pack_sequence_as(true_value, values_flat)
|
|
139
|
+
self.true_branch_name_ctx.__exit__(None, None, None)
|
|
140
|
+
self.false_branch_name_ctx.__enter__()
|
|
141
|
+
self._true_value = true_value
|
|
142
|
+
self._entered_state = False
|
|
143
|
+
|
|
144
|
+
@property
|
|
145
|
+
def false(self) -> T:
|
|
146
|
+
"""
|
|
147
|
+
The getter usually would not be used.
|
|
148
|
+
"""
|
|
149
|
+
return self._false_value
|
|
150
|
+
|
|
151
|
+
@false.setter
|
|
152
|
+
def false(self, false_value: T):
|
|
153
|
+
"""
|
|
154
|
+
Define the False branch value.
|
|
155
|
+
After this, self.result is available.
|
|
156
|
+
"""
|
|
157
|
+
assert self._entered, f"{self} you need to be in the context scope"
|
|
158
|
+
assert (
|
|
159
|
+
self._entered_state is False
|
|
160
|
+
), f"{self} you need to be in the False branch, have assigned :func:`true` before"
|
|
161
|
+
assert false_value is not None
|
|
162
|
+
assert self._false_value is None
|
|
163
|
+
nest.assert_same_structure(self._true_value, false_value)
|
|
164
|
+
# This needs to match the true() setter logic.
|
|
165
|
+
if isinstance(false_value, Tensor):
|
|
166
|
+
false_value = _utils.copy(false_value, name=self.false_branch_name_ctx.get_child("output"))
|
|
167
|
+
else:
|
|
168
|
+
true_values_flat = nest.flatten(self._true_value) # type: List[Tensor]
|
|
169
|
+
false_values_flat = nest.flatten(false_value) # type: List[Tensor]
|
|
170
|
+
assert false_values_flat and len(false_values_flat) == len(true_values_flat)
|
|
171
|
+
for i, (true_v, false_v) in enumerate(zip(true_values_flat, false_values_flat)):
|
|
172
|
+
assert isinstance(true_v, Tensor)
|
|
173
|
+
assert isinstance(
|
|
174
|
+
false_v, Tensor
|
|
175
|
+
), f"unexpected {false_value!r}, only expects tensors, got {type(false_v)}"
|
|
176
|
+
assert true_v.raw_tensor.parent is self.true_branch_name_ctx
|
|
177
|
+
name = true_v.raw_tensor.name
|
|
178
|
+
false_values_flat[i] = _utils.copy(false_v, name=self.false_branch_name_ctx.get_child(name))
|
|
179
|
+
if name != "output":
|
|
180
|
+
false_values_flat[i].raw_tensor.layer_dict["is_output_layer"] = True
|
|
181
|
+
false_value = nest.pack_sequence_as(false_value, false_values_flat)
|
|
182
|
+
self.false_branch_name_ctx.__exit__(None, None, None)
|
|
183
|
+
self._false_value = false_value
|
|
184
|
+
self._result_value = self.layer_module()
|
|
185
|
+
|
|
186
|
+
@property
|
|
187
|
+
def result(self) -> T:
|
|
188
|
+
"""
|
|
189
|
+
:return: the result, after you assigned :func:`true` and :func:`false`.
|
|
190
|
+
"""
|
|
191
|
+
assert self._true_value is not None, f"{self} you need to have defined the true value"
|
|
192
|
+
assert self._false_value is not None, f"{self} you need to have defined the false value"
|
|
193
|
+
assert self._result_value is not None
|
|
194
|
+
return self._result_value
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class CondModule(rf.Module):
|
|
198
|
+
"""
|
|
199
|
+
This module is used internally by :class:`Cond` to create the RETURNN :class:`CondLayer` for the conditional code.
|
|
200
|
+
This module would not be directly used by the user.
|
|
201
|
+
"""
|
|
202
|
+
|
|
203
|
+
def __init__(self, cond: Cond):
|
|
204
|
+
super(CondModule, self).__init__()
|
|
205
|
+
self.cond = cond
|
|
206
|
+
|
|
207
|
+
def __call__(self):
|
|
208
|
+
"""
|
|
209
|
+
Makes layer dict for this loop, i.e. a RecLayer.
|
|
210
|
+
|
|
211
|
+
:return: structure like true_value/false_value
|
|
212
|
+
"""
|
|
213
|
+
name_ctx = self.cond.name_ctx
|
|
214
|
+
# noinspection PyProtectedMember
|
|
215
|
+
true_value, false_value = self.cond._true_value, self.cond._false_value
|
|
216
|
+
true_values_flat = nest.flatten(true_value) # type: List[Tensor]
|
|
217
|
+
false_values_flat = nest.flatten(false_value) # type: List[Tensor]
|
|
218
|
+
assert len(true_values_flat) == len(false_values_flat)
|
|
219
|
+
res = rfl.make_layer(
|
|
220
|
+
{
|
|
221
|
+
"class": "cond",
|
|
222
|
+
"from": [],
|
|
223
|
+
"condition": self.cond.condition,
|
|
224
|
+
"true_layer": {
|
|
225
|
+
"class": "subnetwork",
|
|
226
|
+
"from": [],
|
|
227
|
+
"subnetwork": self.cond.true_branch_name_ctx.make_net(),
|
|
228
|
+
},
|
|
229
|
+
"false_layer": {
|
|
230
|
+
"class": "subnetwork",
|
|
231
|
+
"from": [],
|
|
232
|
+
"subnetwork": self.cond.false_branch_name_ctx.make_net(),
|
|
233
|
+
},
|
|
234
|
+
},
|
|
235
|
+
name=name_ctx,
|
|
236
|
+
predefined_out_data=true_values_flat[0].copy_template(),
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
results = []
|
|
240
|
+
for i, (true_v, false_v) in enumerate(zip(true_values_flat, false_values_flat)):
|
|
241
|
+
assert isinstance(true_v, Tensor) and isinstance(false_v, Tensor)
|
|
242
|
+
assert true_v.raw_tensor.parent is self.cond.true_branch_name_ctx
|
|
243
|
+
name = true_v.raw_tensor.name
|
|
244
|
+
if i == 0:
|
|
245
|
+
results.append(res)
|
|
246
|
+
else:
|
|
247
|
+
# noinspection PyProtectedMember
|
|
248
|
+
results.append(rfl._get_sub_layer(res, name, data=true_v.copy_template()))
|
|
249
|
+
results[-1].raw_tensor.layer_extra_dependencies.extend(
|
|
250
|
+
(self.cond.condition.raw_tensor, true_v.raw_tensor, false_v.raw_tensor)
|
|
251
|
+
)
|
|
252
|
+
return nest.pack_sequence_as(true_value, results)
|
|
@@ -309,6 +309,7 @@ class _ConcatInputLayer(LayerBase):
|
|
|
309
309
|
class CopyLayer(_ConcatInputLayer):
|
|
310
310
|
"""
|
|
311
311
|
This layer does nothing, it copies its input.
|
|
312
|
+
This is not even a ``tf.identity``. It refers to the same TF tensor.
|
|
312
313
|
If multiple sources are provided, they are concatenated in the feature-dim.
|
|
313
314
|
"""
|
|
314
315
|
|
|
@@ -323,6 +324,9 @@ class CopyLayer(_ConcatInputLayer):
|
|
|
323
324
|
We only have this here for the :class:`CopyLayer` because the :func:`get_out_data_from_opts`
|
|
324
325
|
must know about it and define the right beam.
|
|
325
326
|
Also see the option ``collocate_with``, which is different in that it does *not* add a dependency.
|
|
327
|
+
Note that this will not be real TF control dependencies,
|
|
328
|
+
but it simply sets the dependency on the layer.
|
|
329
|
+
If you want to have a real TF control dependency, use :class:`IdentityLayer`.
|
|
326
330
|
"""
|
|
327
331
|
if in_dim and out_dim:
|
|
328
332
|
assert in_dim == out_dim
|
|
@@ -415,6 +419,42 @@ class CopyLayer(_ConcatInputLayer):
|
|
|
415
419
|
d["extra_deps"] = [get_layer(src_name) for src_name in extra_deps]
|
|
416
420
|
|
|
417
421
|
|
|
422
|
+
class IdentityLayer(LayerBase):
|
|
423
|
+
"""
|
|
424
|
+
Wraps ``tf.identity`` with potential control dependencies.
|
|
425
|
+
|
|
426
|
+
The difference to :class:`CopyLayer` is that this creates a new TF op (``tf.identity``),
|
|
427
|
+
which allows for potential control dependencies.
|
|
428
|
+
This is the whole purpose of this layer.
|
|
429
|
+
"""
|
|
430
|
+
|
|
431
|
+
layer_class = "identity"
|
|
432
|
+
|
|
433
|
+
def __init__(self, sources: List[LayerBase], control_dependencies: Sequence[LayerBase], **kwargs):
|
|
434
|
+
super().__init__(sources=sources, **kwargs)
|
|
435
|
+
assert len(sources) == 1
|
|
436
|
+
self.control_dependencies = control_dependencies
|
|
437
|
+
with tf.control_dependencies([src.output.placeholder.op for src in self.control_dependencies]):
|
|
438
|
+
self.output.placeholder = tf.identity(self.sources[0].output.placeholder)
|
|
439
|
+
|
|
440
|
+
def get_dep_layers(self) -> List[LayerBase]:
|
|
441
|
+
"""deps"""
|
|
442
|
+
return super().get_dep_layers() + list(self.control_dependencies)
|
|
443
|
+
|
|
444
|
+
@classmethod
|
|
445
|
+
def get_out_data_from_opts(cls, name: str, sources: List[LayerBase], **kwargs):
|
|
446
|
+
"""out"""
|
|
447
|
+
assert sources
|
|
448
|
+
return sources[0].output.copy(name="%s_output" % name)
|
|
449
|
+
|
|
450
|
+
@classmethod
|
|
451
|
+
def transform_config_dict(cls, d, network, get_layer):
|
|
452
|
+
"""transform"""
|
|
453
|
+
super().transform_config_dict(d, network=network, get_layer=get_layer)
|
|
454
|
+
assert isinstance(d.get("control_dependencies"), (list, tuple))
|
|
455
|
+
d["control_dependencies"] = [get_layer(src_name) for src_name in d["control_dependencies"]]
|
|
456
|
+
|
|
457
|
+
|
|
418
458
|
class ConcatLayer(LayerBase):
|
|
419
459
|
"""
|
|
420
460
|
Concatenates the inputs in specified axes.
|
|
@@ -5850,11 +5850,13 @@ def same_control_flow_ctx(x):
|
|
|
5850
5850
|
with tf.control_dependencies(None) as dep: # this will reset the context
|
|
5851
5851
|
yield dep
|
|
5852
5852
|
return
|
|
5853
|
-
|
|
5854
|
-
|
|
5855
|
-
|
|
5856
|
-
|
|
5857
|
-
|
|
5853
|
+
try:
|
|
5854
|
+
# noinspection PyProtectedMember
|
|
5855
|
+
graph._set_control_flow_context(ctx)
|
|
5856
|
+
yield ctx
|
|
5857
|
+
finally:
|
|
5858
|
+
# noinspection PyProtectedMember
|
|
5859
|
+
graph._set_control_flow_context(cur_ctx)
|
|
5858
5860
|
|
|
5859
5861
|
|
|
5860
5862
|
def get_protobuf_fields(obj):
|
|
@@ -483,8 +483,12 @@ class TorchBackend(Backend[torch.Tensor]):
|
|
|
483
483
|
:return: parameter
|
|
484
484
|
"""
|
|
485
485
|
assert all(d.is_static() for d in tensor.dims)
|
|
486
|
-
data = torch.zeros(
|
|
487
|
-
|
|
486
|
+
data = torch.zeros([d.dimension for d in tensor.dims], dtype=TorchBackend.as_dtype_raw(tensor.dtype))
|
|
487
|
+
if tensor.dtype.startswith("int"):
|
|
488
|
+
requires_grad = False
|
|
489
|
+
else:
|
|
490
|
+
requires_grad = True
|
|
491
|
+
return torch.nn.Parameter(data, requires_grad=requires_grad)
|
|
488
492
|
|
|
489
493
|
@staticmethod
|
|
490
494
|
def set_parameter_initial_value(param: rf.Parameter, value: Union[None, Tensor, rf.RawTensorTypes]) -> None:
|
|
@@ -498,7 +502,8 @@ class TorchBackend(Backend[torch.Tensor]):
|
|
|
498
502
|
assert isinstance(raw_param, torch.nn.Parameter)
|
|
499
503
|
with torch.no_grad():
|
|
500
504
|
if isinstance(value, Tensor):
|
|
501
|
-
|
|
505
|
+
value_ = value.copy_compatible_to(param)
|
|
506
|
+
raw_param.copy_(value_.raw_tensor)
|
|
502
507
|
elif isinstance(value, numpy.ndarray):
|
|
503
508
|
raw_param.copy_(torch.from_numpy(value))
|
|
504
509
|
else:
|
|
@@ -220,6 +220,7 @@ returnn/tf/updater.py
|
|
|
220
220
|
returnn/tf/frontend_layers/__init__.py
|
|
221
221
|
returnn/tf/frontend_layers/_backend.py
|
|
222
222
|
returnn/tf/frontend_layers/_utils.py
|
|
223
|
+
returnn/tf/frontend_layers/cond.py
|
|
223
224
|
returnn/tf/frontend_layers/config_entry_points.py
|
|
224
225
|
returnn/tf/frontend_layers/debug_eager_mode.py
|
|
225
226
|
returnn/tf/frontend_layers/dims.py
|
|
@@ -314,6 +315,7 @@ tests/test_hdf_dump.py
|
|
|
314
315
|
tests/test_rf_array.py
|
|
315
316
|
tests/test_rf_attention.py
|
|
316
317
|
tests/test_rf_base.py
|
|
318
|
+
tests/test_rf_cond.py
|
|
317
319
|
tests/test_rf_container.py
|
|
318
320
|
tests/test_rf_conv.py
|
|
319
321
|
tests/test_rf_encoder_conformer.py
|