returnn 1.20230408.155406__tar.gz → 1.20230409.122444__tar.gz
This diff shows the content of publicly released package versions, as they appear in their respective public registries. It is provided for informational purposes only.
- {returnn-1.20230408.155406/returnn.egg-info → returnn-1.20230409.122444}/PKG-INFO +1 -1
- returnn-1.20230409.122444/_setup_info_generated.py +2 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-rf.config +1 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-torch.config +1 -1
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/_backend.py +31 -1
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/array_.py +11 -1
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/const.py +3 -3
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/run_ctx.py +29 -14
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_dim_extra.py +35 -8
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_tensor_extra.py +39 -1
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/_backend.py +5 -5
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_low_level/_backend.py +12 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/engine.py +50 -29
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/_backend.py +38 -8
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/bridge.py +3 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444/returnn.egg-info}/PKG-INFO +1 -1
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_demos.py +7 -0
- returnn-1.20230408.155406/_setup_info_generated.py +0 -2
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/.editorconfig +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/.gitignore +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/.gitmodules +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/.kateconfig +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/CHANGELOG.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/CODEOWNERS +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/CONTRIBUTING.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/LICENSE +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/MANIFEST.in +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/README.rst +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/12AX.cluster_map +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/_setup_returnn_env.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-fwd.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-horovod-mpi.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-horovod-mpi.py.sh +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-horovod-mpi.sh +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-hyper-param-tuning.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-iter-dataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-list-devices.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-lua-torch-layer.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-pretrain.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-record-and-push-to-webserver.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-returnn-as-framework.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-rhn-enwik8.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-sprint-interface.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-att-copy.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-attention.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-enc-dec.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-hard-att-copy.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-lstm-benchmark.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-native-lstm.12ax.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-native-lstm2.12ax.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-neural-transducer.12ax.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-rec-explicit-lstm.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-rec-explicit-rnn.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-rec-self-att.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-search-compiled-graph.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-timit-lstm-ctc.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo.sh +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/README.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/chars.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/config_demo +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/config_fwd +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/config_real +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/decode.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/go.sh +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/lines.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/split/eval.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/split/train.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/split/valid.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/README.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial/create_test_h5.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial/forwardconfig +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial/go.sh +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial/trainconfig +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial_rgb/go.sh +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/pyproject.toml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/requirements.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/__main__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/__old_mod_loader__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/__setup__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/config.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/audio.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/basic.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/bundle_file.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/cached.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/cached2.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/generating.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/hdf.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/lm.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/map.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/meta.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/multi_proc.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/normalization_data.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/numpy_dump.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/raw_wav.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/sprint.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/stereo.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/util/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/util/feature_extraction.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/util/vocabulary.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/engine/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/engine/base.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/engine/batch.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/__main__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/.git +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/README.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/edit.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/reroute.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/select.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/subgraph.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/transform.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/util.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/_numpy_backend.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/_utils.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/dims.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/dtype.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/init.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/linear.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/loss.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/math_.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/matmul.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/module.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/parameter.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/rand.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/reduce.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/state.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/types.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/import_/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/import_/common.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/import_/git.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/import_/import_.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/learning_rate_control.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/log.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/native_op.cpp +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/native_op.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/pretrain.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/cache.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/control.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/error_signals.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/extern_interface.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/interface.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/README.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_tensor_mixin_base.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_tensor_op_overloads.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/control_flow_ctx.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/dim.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/marked_dim.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/tensor.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/tensor_dict.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/compat.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/data_pipeline.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/distributed.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/engine.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/_utils.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/dims.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/layer.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/make_layer.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_low_level/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/horovod.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/hyper_param_tuning.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/base.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/basic.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/rec.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/segmental_model.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/signal_processing.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/native_op.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/network.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/sprint.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/updater.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/basic.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/data.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/ken_lm.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/open_fst.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/README.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/data/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/data/pipeline.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/data/tensor_utils.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/_rand.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/functional/README.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/functional/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/updater.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/__init__.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/basic.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/better_exchook.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/bpe.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/debug.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/debug_helpers.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/fsa.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/literal_py_to_pickle.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/pprint.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/py-to-pickle.cpp +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/sig_proc.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/task_system.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn.egg-info/SOURCES.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn.egg-info/dependency_links.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn.egg-info/top_level.txt +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/rnn.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/setup.cfg +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/setup.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/DummySprintExec.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm-inspection-profile.xml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/.gitignore +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/.name +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/misc.xml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/modules.xml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/returnn.iml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/_set_num_threads1.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/_setup_returnn_env.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/_setup_test_env.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/bpe-unicode-demo.codes +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/bpe-unicode-demo.vocab +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lexicon_opt.fst +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lexicon_opt.isyms +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lexicon_opt.jpg +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lexicon_opt.osyms +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lint_common.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/pycharm-inspect.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/pylint.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/returnn-as-framework.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/rf_utils.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/spelling.dic +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Config.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Dataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Fsa.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_GeneratingDataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_HDFDataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_LearningRateControl.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Log.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_MultiProcDataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_PTDataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Pretrain.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_ResNet.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_SprintDataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_SprintInterface.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFEngine.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFNativeOp.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFNetworkLayer.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFNetworkRecLayer.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFNetworkSigProcLayer.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFUpdater.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFUtil.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TF_determinism.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TaskSystem.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TaskSystem_SharedMem.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TranslationDataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Util.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_fork_exec.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_hdf_dump.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_rf_base.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_tensor.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_tools.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_torch_frontend.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_torch_internal_frontend.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/_setup_returnn_env.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/analyze-dataset-batches.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bliss-collect-seq-lens.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bliss-dump-text.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bliss-get-segment-names.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bliss-to-ogg-zip.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bpe-create-lexicon.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/calculate-word-error-rate.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/cleanup-old-models.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/collect-orth-symbols.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/collect-words.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/compile_native_op.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/compile_tf_graph.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/debug-dump-search-scores.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/debug-plot-search-scores.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-dataset-raw-strings.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-dataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-forward-stats.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-forward.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-network-json.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-pickle.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/extract_state_tying_from_dataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/get-attention-weights.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/get-best-model-epoch.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/hdf_dump.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/hdf_dump_translation_dataset.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/import-blocks-mt-model.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/import-t2t-mt-model.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/.gitignore +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/Makefile +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/README.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/README.md +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/libs_list +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/state_vars_list +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/tensor_names_list +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/file.h +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/main.cc +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/rescorer.h +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/vocabulary.cc +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/vocabulary.h +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/tf_avg_checkpoints.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/tf_inspect_checkpoint.py +0 -0
- {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/tf_inspect_summary_log.py +0 -0
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-rf.config
RENAMED
@@ -50,6 +50,7 @@ def train_step(*, model: Model, extern_data, **_kwargs):
     data = extern_data["data"]
     logits = model(data)
     targets = extern_data["classes"]
+    # TODO: use flattening on logits/targets
     loss = rf.cross_entropy(estimated=logits, estimated_type="logits", target=targets, axis=out_dim)
     loss.mark_as_loss(name="ce")

{returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-torch.config
RENAMED
@@ -58,7 +58,7 @@ def train_step(*, model: Model, extern_data, **_kwargs):
     targets = extern_data["classes"]
     targets_packed = torch.nn.utils.rnn.pack_padded_sequence(
         targets.raw_tensor, data.dims[1].dyn_size_ext.raw_tensor, batch_first=True, enforce_sorted=False)
-    loss = nn.CrossEntropyLoss()(logits_packed.data, targets_packed.data.long())
+    loss = nn.CrossEntropyLoss(reduction='none')(logits_packed.data, targets_packed.data.long())
    rf.get_run_ctx().mark_as_loss(name="cross_entropy", loss=loss)


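The switch to `reduction='none'` keeps the per-frame cross-entropy values so that the run context can sum and normalize them itself. A minimal sketch of the difference in plain PyTorch (shapes and names below are illustrative, not taken from the demo config):

```python
import torch

logits = torch.randn(7, 5)            # 7 packed frames, 5 classes
targets = torch.randint(0, 5, (7,))

# Old call: reduction defaults to "mean", returning a single averaged scalar.
mean_loss = torch.nn.CrossEntropyLoss()(logits, targets)

# New call: reduction="none" keeps one loss value per frame, so the caller
# can sum them and divide by its own normalization factor.
frame_losses = torch.nn.CrossEntropyLoss(reduction="none")(logits, targets)
assert frame_losses.shape == (7,)
assert torch.allclose(frame_losses.mean(), mean_loss)
```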
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/_backend.py
RENAMED
@@ -208,6 +208,29 @@ class Backend(Generic[T]):
         """
         raise NotImplementedError

+    @staticmethod
+    def cast_raw(raw_tensor: T, dtype: str) -> T:
+        """
+        :param raw_tensor:
+        :param dtype: e.g. "float32"
+        :return: raw tensor with dtype casted
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def cast(tensor: Tensor, dtype: str) -> Tensor:
+        """
+        :param tensor:
+        :param dtype: e.g. "float32"
+        :return: tensor with dtype casted
+        """
+        # Default implementation using cast_raw.
+        res = tensor.copy_template()
+        res.dtype = dtype
+        # noinspection PyProtectedMember
+        res.raw_tensor = tensor._raw_backend.cast_raw(tensor.raw_tensor, dtype)
+        return res
+
     # Restrict the possible activation function names,
     # to not get unexpected behavior,
     # or unwanted incompatibilities.
@@ -287,6 +310,13 @@ class Backend(Generic[T]):
         """
         raise NotImplementedError

+    @staticmethod
+    def have_sequence_mask_raw() -> bool:
+        """
+        :return: whether we have a sequence_mask_raw implementation
+        """
+        return False
+
     @staticmethod
     def sequence_mask_raw(lengths: T, *, batch_major: bool = True) -> T:
         """
@@ -309,7 +339,7 @@
         :return: context manager
         """
         # Default implementation for eager-based frameworks
-
+        yield  # nothing to do

     @staticmethod
     @contextlib.contextmanager
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/array_.py
RENAMED
@@ -12,7 +12,7 @@ from .types import RawTensorTypes

 T = TypeVar("T")

-__all__ = ["convert_to_tensor", "constant", "gather"]
+__all__ = ["convert_to_tensor", "constant", "cast", "gather"]


 def convert_to_tensor(
@@ -77,6 +77,16 @@ def convert_to_tensor(
 constant = convert_to_tensor  # alias for some older code


+def cast(tensor: Tensor, dtype: str) -> Tensor:
+    """
+    :param tensor:
+    :param dtype:
+    :return: tensor with the same data, but with a different dtype
+    """
+    # noinspection PyProtectedMember
+    return tensor._raw_backend.cast(tensor, dtype=dtype)
+
+
 # noinspection PyUnusedLocal
 def gather(
     source: Tensor,
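With `cast` now exported from the array module, dtype conversion can go through the frontend instead of touching the raw backend tensor directly. A small usage sketch, assuming the helpers are re-exported under `returnn.frontend` as usual; the dim and tensor names are made up:

```python
import returnn.frontend as rf
from returnn.tensor import Dim

feat_dim = Dim(4, name="feature")               # hypothetical static dim
counts = rf.zeros([feat_dim], dtype="int32")    # some integer-valued tensor

# Only the dtype changes; dims and other metadata are kept.
counts_f = rf.cast(counts, dtype="float32")
assert counts_f.dtype == "float32" and counts_f.dims == counts.dims
```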
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/const.py
RENAMED
@@ -14,7 +14,7 @@ __all__ = ["full", "constant", "fill", "zeros", "ones"]


 def full(
-    dims: Sequence[Dim], fill_value: RawTensorTypes,
+    *, dims: Sequence[Dim], fill_value: RawTensorTypes, dtype: Optional[str] = None, sparse_dim: Optional[Dim] = None
 ) -> Tensor:
     """
     full
@@ -46,11 +46,11 @@ def zeros(dims: Sequence[Dim], *, dtype: Optional[str] = None, sparse_dim: Optional[Dim] = None) -> Tensor:
     """
     zeros. float by default.
     """
-    return full(dims, 0, dtype=dtype or rf.get_default_float_dtype(), sparse_dim=sparse_dim)
+    return full(dims=dims, fill_value=0, dtype=dtype or rf.get_default_float_dtype(), sparse_dim=sparse_dim)


 def ones(dims: Sequence[Dim], *, dtype: Optional[str] = None, sparse_dim: Optional[Dim] = None) -> Tensor:
     """
     ones. float by default.
     """
-    return full(dims, 1, dtype=dtype or rf.get_default_float_dtype(), sparse_dim=sparse_dim)
+    return full(dims=dims, fill_value=1, dtype=dtype or rf.get_default_float_dtype(), sparse_dim=sparse_dim)
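`full` is now keyword-only and also accepts `dtype` and `sparse_dim`, so positional calls like `full(dims, 0)` no longer work; `zeros` and `ones` were adapted accordingly. A sketch of the new calling convention (dim names are illustrative, assuming the usual `returnn.frontend` re-exports):

```python
import returnn.frontend as rf
from returnn.tensor import Dim

time_dim = Dim(10, name="time")
feat_dim = Dim(3, name="feature")

# Keyword-only style for full:
x = rf.full(dims=[time_dim, feat_dim], fill_value=0.5, dtype="float32")

# zeros/ones still take dims positionally and forward to full with keywords:
y = rf.zeros([time_dim, feat_dim])               # float by default
z = rf.ones([time_dim, feat_dim], dtype="int32")
```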
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/run_ctx.py
RENAMED
@@ -119,24 +119,18 @@ class RunCtx:
         E.g. if the overall normalization is sum(loss)/sum(num_frames), this is also what the optimizer will use,
         otherwise the optimizer will just use sum(loss).
         :param custom_inv_norm_factor:
-            The standard norm factor is
-            or
+            The standard inv norm factor is sum(target_seq_len) if the target has a time-axis,
+            or sum(output_seq_len) if there is no target and the output has a time-axis,
             or 1 otherwise. (See :func:`Loss.init` for details.)
             This is used for proper normalization of accumulated loss/error per epoch
             and also proper normalization per batch for reporting,
             no matter if use_normalized_loss is True or False.
             If you want to change this norm factor, you can set this.
-            Basically, for all reporting, it uses sum(loss)
+            Basically, for all reporting, it uses sum(loss) / sum(custom_inv_norm_factor).
         """
         assert self.stage == "train_step"
         if not isinstance(loss, Tensor):
             assert isinstance(loss, _backend.global_backend.RawTensorType)
-            assert _backend.global_backend.get_ndim_raw(loss) == 0, (
-                f"mark_as_loss(<loss with shape {_backend.global_backend.get_known_shape_raw(loss)}>, {name!r}):"
-                " Only scalar raw losses are supported,"
-                " because we cannot know whether there are any dynamic dims which might require padding."
-                " Explicitly convert to a Tensor first and specify dim tags."
-            )
             loss = rf.convert_to_tensor(loss)
         assert name not in self.losses
         self.losses[name] = Loss(
@@ -220,31 +214,52 @@ class Loss:

     scale: float = 1.0
     as_error: bool = False
-    use_normalized_loss: bool = False
+    use_normalized_loss: bool = False  # for the gradient / total loss
     use_flatten_frames: bool = True
     custom_inv_norm_factor: Optional[Tensor] = None

+    _summed_loss_cached: Optional[Tensor] = None
+    _mean_loss_cached: Optional[Tensor] = None
+
     def get_summed_loss(self) -> Tensor:
         """
         :return: sum of loss (scalar)
         """
         if not self.loss.dims:
             return self.loss
-
+        if self._summed_loss_cached is not None:
+            return self._summed_loss_cached
+        if self._mean_loss_cached is not None:
+            return self._mean_loss_cached / self.get_inv_norm_factor()
+        self._summed_loss_cached = rf.reduce_sum(self.loss, axis=self.loss.dims)
+        return self._summed_loss_cached

     def get_mean_loss(self) -> Tensor:
         """
         :return: sum of loss (scalar)
         """
+        if self._mean_loss_cached is not None:
+            return self._mean_loss_cached
         if self.custom_inv_norm_factor:
-
+            loss = self.get_summed_loss()
+            loss /= rf.cast(self.custom_inv_norm_factor, dtype=loss.dtype)
+            return loss
         if not self.loss.dims:
             return self.loss
-
+        self._mean_loss_cached = rf.reduce_mean(self.loss, axis=self.loss.dims)
+        return self._mean_loss_cached
+
+    def get_inv_norm_factor(self) -> Union[int, Tensor]:
+        """
+        :return: inverse norm factor (scalar)
+        """
+        if self.custom_inv_norm_factor:
+            return self.custom_inv_norm_factor
+        return self.loss.num_elements()

     def get_scaled_reduced_loss(self) -> Tensor:
         """
-        :return: scaled reduced loss (scalar), as it is supposed to be used for calculating the
+        :return: scaled reduced loss (scalar), as it is supposed to be used for calculating the train gradient
         """
         if self.use_normalized_loss:
             loss = self.get_mean_loss()
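The new accessors tie summed loss, mean loss and the inverse norm factor together: the mean is the sum divided by the norm factor (the number of unpadded elements, or `custom_inv_norm_factor` if set). A plain-Python sketch of that relationship with made-up numbers:

```python
frame_losses = [0.7, 0.2, 0.4, 0.3]   # one CE value per unpadded frame
summed = sum(frame_losses)            # what get_summed_loss() returns: 1.6
inv_norm = len(frame_losses)          # get_inv_norm_factor(): 4 frames
mean = summed / inv_norm              # get_mean_loss(): 0.4

# With a custom_inv_norm_factor (e.g. number of target labels), the same
# summed loss is divided by that factor instead:
custom_inv_norm_factor = 8
custom_mean = summed / custom_inv_norm_factor   # 0.2
print(mean, custom_mean)
```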
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_dim_extra.py
RENAMED
@@ -747,7 +747,11 @@ class _DimMixin:
         :return: whether dim is static or dynamic but with scalar dyn_size_ext
         """
         if self.is_static():
+            if self.capacity is not None:
+                return self.size < self.capacity
             return False
+        if self.capacity is not None:
+            return True
         if not self.dyn_size_ext:
             return True  # unknown
         return self.dyn_size_ext.batch_ndim > 0
@@ -1516,6 +1520,21 @@ class _DimMixin:
         If `self.src_data` has a placeholder, will use the shape from there.
         Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).

+        :return: max(size or dyn_size)
+        """
+        res = self.get_dim_value_tensor()
+        if isinstance(res, _t.Tensor):
+            assert res.dims == ()
+            return res.raw_tensor
+        assert isinstance(res, int)
+        return res
+
+    def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:
+        """
+        Infers the dim this axis should have if unbroadcasted.
+        If `self.src_data` has a placeholder, will use the shape from there.
+        Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).
+
         :return: max(size or dyn_size)
         """
         import returnn.frontend as rf
@@ -1530,25 +1549,33 @@ class _DimMixin:
                 # Masking is not always possible here, e.g.
                 # self = Dim{'self-att-keys'['time:var:extern_data:classes'[B]]}.
                 use_time_mask=False,
-            )
-            return self.dyn_size_ext
+            )
+            return self.dyn_size_ext
         if self.is_batch_dim():
+            res = None
             if self._extra and self._extra.src_data:
-
-
-
+                res = self._extra.src_data.get_batch_dim()
+            elif self.batch:
+                res = self.batch.dim
+            if isinstance(res, int):
+                return res
+            if res is not None:
+                return _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
         if (
             self._extra
             and self._extra.src_data is not None
             and self._extra.src_axis is not None
             and self._extra.src_data.placeholder is not None
         ):
-
+            res = self._extra.src_data.get_dim(self._extra.src_axis)
+            if isinstance(res, int):
+                return res
+            return _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
         self.complete_dyn_size()
         if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None:
             if self.dyn_size_ext.batch_ndim > 0:
-                return rf.reduce_max(self.dyn_size_ext, axis=self.dyn_size_ext.dim_tags)
-            return self.dyn_size_ext
+                return rf.reduce_max(self.dyn_size_ext, axis=self.dyn_size_ext.dim_tags)
+            return self.dyn_size_ext
         raise Exception("%s: need placeholder, self.dimension or self.dyn_size for dim value" % self)

     def axis_split_info(self):
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_tensor_extra.py
RENAMED
@@ -2676,7 +2676,11 @@ class _TensorMixin(_TensorMixinBase):
         backend = tag.dyn_size_ext._raw_backend
         assert set(tag.dyn_size_ext.dim_tags).issubset(self.dim_tags)  # https://github.com/rwth-i6/returnn/issues/721
         with backend.name_scope_raw("get_sequence_mask_broadcast"):
-            if
+            if (
+                backend.have_sequence_mask_raw()
+                and tag.dyn_size_ext.have_batch_axis()
+                and tag.dyn_size_ext.batch_ndim == 1
+            ):  # just [B]
                 # This is the common case where the size is of shape [B].
                 # We make use of sequence_mask or sequence_mask_time_major in that case,
                 # which is optimized by caching.
@@ -2733,11 +2737,45 @@
         assert tag.dyn_size_ext
         return tag.dyn_size_ext.copy_compatible_to(self, check_dtype=False, check_sparse=False).placeholder

+    def num_elements(self: Tensor) -> Union[int, Tensor]:
+        """
+        :return: number of elements in this tensor, i.e. prod(self.shape)
+        :rtype: tf.Tensor
+        """
+        if all(dim.is_static() for dim in self.dims):
+            n = 1
+            for dim in self.dims:
+                n *= dim.dimension
+            return n
+
+        import returnn.frontend as rf
+
+        n = 1
+        dims = list(self.dims)
+        dims.sort(key=lambda dim: -dim.dyn_size_ext.batch_ndim if dim.dyn_size_ext else 0)
+        while dims:
+            dim = dims.pop(0)
+            if dim.is_static():
+                n *= dim.dimension
+                continue
+            # E.g. dyn_size_ext is shape [B], and self has shape [B,T].
+            # Due to the sorting of dims above, dims will be [T,B], and we will first process T.
+            # We want to sum over dyn_size_ext, but then we need to remove the other dims it covers.
+            for dim_ in dim.dyn_size_ext.dims:
+                assert dim_ in dims  # num elements not really well-defined then
+                assert not dim_.need_masking()  # not implemented
+                dims.remove(dim_)
+            n_ = rf.reduce_sum(dim.dyn_size_ext, axis=dim.dyn_size_ext.dims)
+            n *= n_
+        return n
+
     def copy_masked(self: Tensor, mask_value) -> Tensor:
         """
         :param float|int|tf.Tensor mask_value:
         """
         assert self.placeholder is not None
+        if not any(dim.need_masking() for dim in self.dims):
+            return self.copy()
         assert self._raw_backend.is_tensorflow  # not implemented otherwise for now
         from returnn.tf.util.basic import mask_dyn_seq_len_nd
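`num_elements` counts actual (unpadded) elements: purely static dims contribute their product, while a dynamic dim contributes the sum of its sequence lengths. A small arithmetic sketch with illustrative numbers:

```python
# Tensor of shape [B, T, F] with B=2, F=3 and per-sequence lengths T = [4, 6].
seq_lens = [4, 6]
feat = 3

padded_elements = len(seq_lens) * max(seq_lens) * feat   # 2 * 6 * 3 = 36
num_elements = sum(seq_lens) * feat                       # 10 * 3 = 30, padding excluded
print(padded_elements, num_elements)
```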
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/_backend.py
RENAMED
@@ -122,6 +122,11 @@ class ReturnnLayersBackend(Backend[Layer]):
         """transpose_raw is a no-op in this backend"""
         return raw_tensor

+    @staticmethod
+    def cast(tensor: Tensor, dtype: str) -> Tensor:
+        """cast"""
+        return rfl.make_layer({"class": "cast", "from": tensor, "dtype": dtype}, name="cast")
+
     @staticmethod
     def activation(tensor: Tensor, func: str) -> Tensor:
         """activation"""
@@ -172,11 +177,6 @@ class ReturnnLayersBackend(Backend[Layer]):
         log_probs = rf.log_softmax(logits, axis=axis)
         return -rf.matmul(targets, log_probs, reduce=axis)

-    @staticmethod
-    def sequence_mask_raw(lengths: Layer, *, batch_major: bool = True) -> Layer:
-        """sequence mask"""
-        raise NotImplementedError  # TODO
-
     @staticmethod
     def create_parameter_raw(tensor: rf.Parameter) -> Layer:
         """create parameter"""
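In the layers-based TF frontend, the new `cast` is expressed as a regular RETURNN `"cast"` layer via `make_layer`. A hedged sketch of the network-dict fragment this corresponds to (the source layer name `"encoder"` is made up for illustration):

```python
# Net-dict fragment equivalent to the make_layer call above, assuming the
# input comes from a layer called "encoder":
network = {
    "cast": {"class": "cast", "from": "encoder", "dtype": "float32"},
}
```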
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_low_level/_backend.py
RENAMED
@@ -194,6 +194,11 @@ class TFBackend(Backend[tf.Tensor]):
         with tf_util.same_control_flow_ctx(raw_tensor):
             return tf.tile(raw_tensor, [1] * axis + [dim] + [1] * (raw_tensor.shape.ndims - axis - 1))

+    @staticmethod
+    def cast_raw(raw_tensor: tf.Tensor, dtype: str) -> tf.Tensor:
+        """cast"""
+        return tf.cast(raw_tensor, dtype)
+
     @staticmethod
     def activation_raw(raw_tensor: tf.Tensor, func: str) -> tf.Tensor:
         """
@@ -212,6 +217,13 @@ class TFBackend(Backend[tf.Tensor]):
             raise ValueError(f"unknown activation function {func!r}")
         return f(raw_tensor)

+    @staticmethod
+    def have_sequence_mask_raw() -> bool:
+        """
+        :return: whether we have sequence_mask
+        """
+        return True
+
     @staticmethod
     def sequence_mask_raw(lengths: tf.Tensor, *, batch_major: bool = True) -> tf.Tensor:
         """
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/engine.py
RENAMED
@@ -141,34 +141,44 @@ class Engine(EngineBase):
         self._pt_model.train()

         accumulated_losses_dict = NumbersDict()
+        accumulated_inv_norm_factors_dict = NumbersDict()
         step_idx = 0
         for data in self._train_dataloader:
             self._run_step(data)

             train_ctx = rf.get_run_ctx()
-            losses_dict = train_ctx.losses
             total_loss = train_ctx.total_loss()
+            losses_dict = NumbersDict(
+                {
+                    name: float(loss.get_summed_loss().raw_tensor.detach().cpu().numpy())
+                    for name, loss in train_ctx.losses.items()
+                }
+            )
+            inv_norm_factors_dict = NumbersDict(
+                {name: float(_to_raw(loss.get_inv_norm_factor())) for name, loss in train_ctx.losses.items()}
+            )

             self._updater.get_optimizer().zero_grad()
             total_loss.raw_tensor.backward()
             self._updater.get_optimizer().step()

-
-
-
-            }
-            accumulated_losses_dict += NumbersDict(losses_dict)
-            print("step %i, loss: %f" % (step_idx, total_loss.raw_tensor.detach().cpu().numpy()), file=log.v4)
+            accumulated_losses_dict += losses_dict
+            accumulated_inv_norm_factors_dict += inv_norm_factors_dict
+            print(f"step {step_idx}, loss: {dict(losses_dict / inv_norm_factors_dict)}", file=log.v4)

             step_idx += 1
             self._train_step += 1

         print("Trained %i steps" % step_idx)

-        accumulated_losses_dict = accumulated_losses_dict /
-        self.learning_rate_control.set_epoch_error(
+        accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict
+        self.learning_rate_control.set_epoch_error(
+            self.epoch, {f"train_loss_{k}": v for k, v in accumulated_losses_dict.items()}
+        )
         self.learning_rate_control.save()

+        print(f"Total train loss: {dict(accumulated_losses_dict)}", file=log.v3)
+
         if self.epoch % self._save_model_epoch_interval == 0 or self.epoch == self._final_epoch:
             self._save_model()
             self._save_optimizer()
@@ -186,8 +196,8 @@ class Engine(EngineBase):

         data_loader = self._eval_dataloaders[dataset_name]

-        accumulated_loss = 0.0
         accumulated_losses_dict = NumbersDict()
+        accumulated_inv_norm_factors_dict = NumbersDict()
         step_idx = 0

         with torch.no_grad():
@@ -195,29 +205,31 @@ class Engine(EngineBase):

                 self._run_step(data)
                 train_ctx = rf.get_run_ctx()
-                losses_dict = train_ctx.losses
-                total_loss = train_ctx.total_loss()
-
-                total_loss = total_loss.raw_tensor.detach().cpu().numpy()
-                losses_dict = {
-                    dataset_name + "_loss_" + name: float(loss.loss.raw_tensor.detach().cpu().numpy())
-                    for name, loss in losses_dict.items()
-                }
-                print("step %i, loss: %f" % (step_idx, total_loss), file=log.v4)
-
-                accumulated_loss += total_loss
-                accumulated_losses_dict += NumbersDict(losses_dict)
-                step_idx += 1

-
-
-
+                losses_dict = NumbersDict(
+                    {
+                        name: float(loss.get_summed_loss().raw_tensor.detach().cpu().numpy())
+                        for name, loss in train_ctx.losses.items()
+                    }
+                )
+                inv_norm_factors_dict = NumbersDict(
+                    {name: float(_to_raw(loss.get_inv_norm_factor())) for name, loss in train_ctx.losses.items()}
+                )
+
+                accumulated_losses_dict += losses_dict
+                accumulated_inv_norm_factors_dict += inv_norm_factors_dict
+                print(f"step {step_idx}, loss: {dict(losses_dict / inv_norm_factors_dict)}", file=log.v4)
+                step_idx += 1

-
+        assert step_idx > 0, f"No data in dataset {dataset_name!r}."
+        accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict

-
+        self.learning_rate_control.set_epoch_error(
+            self.epoch, {f"{dataset_name}_loss_{k}": v for k, v in accumulated_losses_dict.items()}
+        )
+        self.learning_rate_control.save()

-
+        print(f"Total loss for {dataset_name!r}: {dict(accumulated_losses_dict)}", file=log.v3)

     def _create_data_loader(self, dataset: Dataset) -> DataLoader2:
         """
@@ -312,6 +324,7 @@ class Engine(EngineBase):
         else:
             raise TypeError(f"get_model returned {model} of type {type(model)}, expected rf.Module or torch.nn.Module")
         assert isinstance(self._pt_model, torch.nn.Module)
+        print("Model:", self._pt_model, file=log.v4)

         if checkpoint_state is not None:
             self._pt_model.load_state_dict(checkpoint_state["model"])
@@ -404,3 +417,11 @@ class Engine(EngineBase):
         os.makedirs(directory, exist_ok=True)

         self._updater.save_optimizer(filename)
+
+
+def _to_raw(n: Union[int, float, Tensor]):
+    if isinstance(n, (int, float)):
+        return n
+    if isinstance(n, Tensor):
+        return n.raw_tensor.detach().cpu().numpy()
+    raise TypeError(f"Unexpected {n} of type {type(n)}")
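Both the training and the evaluation loop now keep two accumulators, summed losses and inverse norm factors, and only divide when reporting or passing epoch errors to the learning-rate control. A bookkeeping sketch with plain floats for a single hypothetical "ce" loss (NumbersDict does the same per loss name):

```python
acc_loss, acc_norm = 0.0, 0.0
for step_loss, step_norm in [(12.0, 40.0), (9.0, 30.0), (15.0, 50.0)]:
    acc_loss += step_loss               # summed loss of the step
    acc_norm += step_norm               # e.g. number of target frames in the step
    print(f"step loss: {step_loss / step_norm:.3f}")

# Normalized epoch value, e.g. what would be reported as "train_loss_ce":
print(f"train_loss_ce: {acc_loss / acc_norm:.3f}")
```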
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/_backend.py
RENAMED
@@ -116,6 +116,11 @@ class TorchBackend(Backend[torch.Tensor]):
         """
         return raw_tensor.unsqueeze(axis)

+    @staticmethod
+    def cast_raw(raw_tensor: torch.Tensor, dtype: str) -> torch.Tensor:
+        """cast"""
+        return raw_tensor.to(dtype=TorchBackend.as_dtype_raw(dtype))
+
     @staticmethod
     def activation_raw(raw_tensor: torch.Tensor, func: str) -> torch.Tensor:
         """
@@ -411,6 +416,21 @@ class TorchBackend(Backend[torch.Tensor]):

         return result_tensor

+    @staticmethod
+    def range_over_dim(dim: Dim) -> Tensor[torch.Tensor]:
+        """
+        :param dim:
+        :return: tensor with shape [dim]
+        """
+        out = Tensor(
+            "range",
+            dims=[dim],
+            sparse_dim=dim,
+            dtype=dim.dyn_size_ext.dtype if dim.dyn_size_ext else rf.get_default_array_index_dtype(),
+        )
+        out.raw_tensor = torch.arange(dim.get_dim_value())
+        return out
+
     @staticmethod
     def reduce(
         source: Tensor[torch.Tensor],
@@ -422,15 +442,25 @@
         """reduce"""
         assert mode in Backend._AllowedReduceModes
         if isinstance(axis, Dim):
-
-
-
+            axis = [axis]
+        assert all(isinstance(dim, Dim) for dim in axis)
+        if use_time_mask is not False and any(dim.need_masking() for dim in axis):
+            source = source.copy()
+            dtype = source.raw_tensor.dtype
+            if mode == "max":
+                mask_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+            elif mode == "min":
+                mask_value = torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max
+            elif mode == "sum":
+                mask_value = 0
+            else:
+                raise NotImplementedError(f"reduce_{mode} not implemented with masking on tensor {source!r}.")
+            for i, dim in enumerate(axis):
+                if dim.need_masking():
+                    mask = source.get_sequence_mask_broadcast(axis=i)
+                    source.raw_tensor = torch.where(mask, source.raw_tensor, mask_value)
         func = getattr(torch, mode)
-        raw_dims = (
-            [source.get_axis_from_description(axis)]
-            if isinstance(axis, Dim)
-            else [source.get_axis_from_description(dim) for dim in axis]
-        )
+        raw_dims = [source.get_axis_from_description(dim) for dim in axis]
         res_dims = [dim for i, dim in enumerate(source.dims) if i not in raw_dims]
         if not res_dims:
             raw_result = func(source.raw_tensor)
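The PyTorch `reduce` now masks padded positions before reducing, so `reduce_max`/`min`/`sum` over a dynamic axis ignore padding. The idea in plain PyTorch (shapes and values are illustrative):

```python
import torch

x = torch.tensor([[1.0, 5.0, 9.0],
                  [2.0, 7.0, 0.0]])     # [B=2, T=3]; second sequence has length 2
seq_lens = torch.tensor([3, 2])
mask = torch.arange(x.shape[1])[None, :] < seq_lens[:, None]   # [B, T] bool

# For a max-reduce, padded positions are replaced by the dtype minimum first,
# so they can never win the reduction.
neutral = torch.finfo(x.dtype).min
masked_max = torch.where(mask, x, torch.full_like(x, neutral)).amax(dim=1)
print(masked_max)   # tensor([9., 7.]) -- the padded 0.0 in row 2 is ignored
```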
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/bridge.py
RENAMED
@@ -76,6 +76,9 @@ class _RFModuleAsPTModule(torch.nn.Module):
             pt_mod = rf_module_to_pt_module(rf_mod)
             self.add_module(name, pt_mod)

+    def _get_name(self):
+        return self._rf_module.__class__.__name__ + "[RF→PT]"
+
     @property
     def rf_module(self) -> rf.Module:
         """RF module"""
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_demos.py
RENAMED
@@ -139,6 +139,13 @@ def test_demo_torch_task12ax():
     # TODO also check FER. So far this is not properly reported. https://github.com/rwth-i6/returnn/issues/1120


+@unittest.skipIf(not torch, "no PyTorch")
+def test_demo_rf_torch_task12ax():
+    cleanup_tmp_models("demos/demo-rf.config")
+    run(py, "rnn.py", "demos/demo-rf.config", print_stdout=True)
+    # TODO also check FER. So far this is not properly reported. https://github.com/rwth-i6/returnn/issues/1120
+
+
 def test_demo_iter_dataset_task12ax():
     # there should be no actual TF dependency, we just iterate the dataset
     cleanup_tmp_models("demos/demo-tf-vanilla-lstm.12ax.config")