returnn 1.20240705.144031.tar.gz → 1.20240709.122157.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/PKG-INFO +1 -1
- returnn-1.20240709.122157/_setup_info_generated.py +2 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/meta.py +29 -1
- returnn-1.20240709.122157/returnn/torch/util/gradient_checkpoint.py +594 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn.egg-info/PKG-INFO +1 -1
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn.egg-info/SOURCES.txt +2 -0
- returnn-1.20240709.122157/tests/test_torch_util.py +303 -0
- returnn-1.20240705.144031/_setup_info_generated.py +0 -2
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/.editorconfig +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/.gitignore +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/.gitmodules +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/.kateconfig +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/CHANGELOG.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/CODEOWNERS +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/CONTRIBUTING.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/LICENSE +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/MANIFEST.in +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/README.rst +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/12AX.cluster_map +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/_setup_returnn_env.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-fwd.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-horovod-mpi.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-horovod-mpi.py.sh +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-horovod-mpi.sh +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-hyper-param-tuning.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-iter-dataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-list-devices.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-lua-torch-layer.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-pretrain.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-record-and-push-to-webserver.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-returnn-as-framework.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-rf-pt-benchmark.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-rf.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-rhn-enwik8.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-sprint-interface.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-att-copy.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-attention.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-enc-dec.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-hard-att-copy.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-lstm-benchmark.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-native-lstm.12ax.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-native-lstm2.12ax.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-neural-transducer.12ax.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-rec-explicit-lstm.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-rec-explicit-rnn.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-rec-self-att.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-search-compiled-graph.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-timit-lstm-ctc.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-torch.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo.sh +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/README.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/chars.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/config_demo +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/config_fwd +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/config_real +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/decode.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/go.sh +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/lines.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/split/eval.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/split/train.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/split/valid.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial/create_test_h5.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial/forwardconfig +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial/go.sh +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial/trainconfig +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial_rgb/go.sh +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/pyproject.toml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/requirements.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/__main__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/__old_mod_loader__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/__setup__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/config.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/audio.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/basic.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/bundle_file.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/cached.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/cached2.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/distrib_files.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/generating.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/hdf.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/lm.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/map.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/multi_proc.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/normalization_data.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/numpy_dump.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/raw_wav.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/sprint.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/stereo.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/util/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/util/feature_extraction.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/util/strings.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/util/vocabulary.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/engine/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/engine/base.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/engine/batch.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/__main__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/.git +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/edit.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/reroute.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/select.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/subgraph.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/transform.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/util.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/forward_iface.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_backend.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/backend.cpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/backend.hpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/module.cpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/module.hpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/py_utils.hpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/tensor_ops.cpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/tensor_ops.hpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_numpy_backend.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_random_journal.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_utils.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/array_.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/attention.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/audio/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/audio/mel.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/audio/specaugment.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/backend.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/build_from_dict.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/cond.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/const.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/container.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/control_flow_ctx.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/conv.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/decoder/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/decoder/transformer.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/device.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/dims.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/dropout.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/dtype.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/encoder/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/encoder/base.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/encoder/conformer.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/gradient.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/graph.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/hooks.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/init.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/label_smoothing.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/linear.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/loop.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/loss.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/math_.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/matmul.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/module.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/normalization.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/parameter.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/piecewise_linear.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/rand.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/rec.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/reduce.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/run_ctx.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/signal.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/state.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/stepwise_scheduler.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/tensor_array.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/types.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/import_/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/import_/common.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/import_/git.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/import_/import_.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/learning_rate_control.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/log.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/native_op.cpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/native_op.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/pretrain.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/cache.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/control.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/error_signals.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/extern_interface.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/interface.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/_dim_extra.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/_tensor_extra.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/_tensor_mixin_base.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/_tensor_op_overloads.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/control_flow_ctx.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/dim.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/marked_dim.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/tensor.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/tensor_dict.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/utils.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/compat.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/data_pipeline.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/distributed.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/engine.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/_backend.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/_utils.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/cond.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/dims.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/layer.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/loop.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/make_layer.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/masked_computation.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/parameter_assign.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_low_level/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_low_level/_backend.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/horovod.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/hyper_param_tuning.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/base.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/basic.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/rec.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/segmental_model.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/signal_processing.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/variable.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/native_op.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/network.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/sprint.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/updater.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/basic.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/data.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/gradient_checkpoint.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/ken_lm.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/open_fst.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/extern_data.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/pipeline.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/queued_data_iter.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/tensor_utils.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/distributed.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/engine.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/_backend.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/_rand.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/bridge.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/raw_ops.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/updater.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/util/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/util/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/util/diagnose_gpu.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/util/scaled_gradient.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/__init__.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/basic.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/better_exchook.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/bpe.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/debug.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/debug_helpers.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/file_cache.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/fsa.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/literal_py_to_pickle.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/math.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/multi_proc_non_daemonic_spawn.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/native_code_compiler.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/pprint.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/py-to-pickle.cpp +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/py_compat.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/py_ext_mod_compiler.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/result_with_reason.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/sig_proc.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/task_system.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/train_proc_manager.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/watch_memory.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn.egg-info/dependency_links.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn.egg-info/top_level.txt +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/rnn.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/setup.cfg +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/setup.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/DummySprintExec.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm-inspection-profile.xml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/.gitignore +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/.name +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/misc.xml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/modules.xml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/returnn.iml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/_set_num_threads1.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/_setup_returnn_env.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/_setup_test_env.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/bpe-unicode-demo.codes +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/bpe-unicode-demo.vocab +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lexicon_opt.fst +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lexicon_opt.isyms +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lexicon_opt.jpg +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lexicon_opt.osyms +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lint_common.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/pycharm-inspect.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/pylint.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/returnn-as-framework.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/rf_utils.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/spelling.dic +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Config.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Dataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Fsa.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_GeneratingDataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_HDFDataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_LearningRateControl.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Log.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_MultiProcDataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Pretrain.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_ResNet.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_SprintDataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_SprintInterface.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFEngine.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFNativeOp.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFNetworkLayer.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFNetworkRecLayer.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFNetworkSigProcLayer.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFUpdater.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFUtil.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TF_determinism.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TaskSystem.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TaskSystem_SharedMem.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TranslationDataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Util.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_demos.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_fork_exec.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_hdf_dump.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_array.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_attention.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_base.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_cond.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_const.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_container.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_conv.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_encoder_conformer.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_gradient.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_label_smoothing.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_loop.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_math.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_normalization.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_piecewise_linear.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_rec.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_reduce.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_signal.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_tensor.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_tools.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_torch_dataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_torch_engine.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_torch_frontend.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_torch_internal_frontend.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/_setup_returnn_env.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/analyze-dataset-batches.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bliss-collect-seq-lens.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bliss-dump-text.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bliss-get-segment-names.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bliss-to-ogg-zip.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bpe-create-lexicon.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/calculate-word-error-rate.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/cleanup-old-models.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/collect-orth-symbols.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/collect-words.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/compile_native_op.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/compile_tf_graph.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/debug-dump-search-scores.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/debug-plot-search-scores.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-dataset-raw-strings.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-dataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-forward-stats.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-forward.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-network-json.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-pickle.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/extract_state_tying_from_dataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/get-attention-weights.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/get-best-model-epoch.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/hdf_dump.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/hdf_dump_translation_dataset.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/import-blocks-mt-model.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/import-t2t-mt-model.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/.gitignore +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/Makefile +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/README.md +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/libs_list +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/state_vars_list +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/tensor_names_list +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/file.h +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/main.cc +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/rescorer.h +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/vocabulary.cc +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/vocabulary.h +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/tf_avg_checkpoints.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/tf_inspect_checkpoint.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/tf_inspect_summary_log.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/torch_avg_checkpoints.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/torch_export_to_onnx.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/torch_inspect_checkpoint.py +0 -0
- {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/torch_inspect_checkpoint_and_opt.py +0 -0
|
@@ -1391,6 +1391,7 @@ class ConcatSeqsDataset(CachedDataset2):
|
|
|
1391
1391
|
seq_tag_delim=";",
|
|
1392
1392
|
remove_in_between_postfix=None,
|
|
1393
1393
|
repeat_in_between_last_frame_up_to_multiple_of=None,
|
|
1394
|
+
pad_narrow_data_to_multiple_of_target_len=None,
|
|
1394
1395
|
use_cache_manager=False,
|
|
1395
1396
|
epoch_wise_filter=None,
|
|
1396
1397
|
**kwargs,
|
|
@@ -1406,6 +1407,12 @@ class ConcatSeqsDataset(CachedDataset2):
|
|
|
1406
1407
|
Now it could happen that ceildiv(data_len1 + data_len2, 6) < align_len1 + align_len2.
|
|
1407
1408
|
This option would repeat intermediate ending frames such that data_len1 % 6 == 0,
|
|
1408
1409
|
by setting it to {"data": 6}.
|
|
1410
|
+
:param dict[str,(str,int)]|None pad_narrow_data_to_multiple_of_target_len: data_key -> (target_key, multiple).
|
|
1411
|
+
Similar to repeat_in_between_last_frame_up_to_multiple_of, but works for more padding/alignment schemes.
|
|
1412
|
+
Example: align_len == ceildiv(data_len - P, F) for all your sub-sequences, where P is a custom number,
|
|
1413
|
+
repeat_in_between_last_frame_up_to_multiple_of would not work because align_len != ceildiv(data_len, F)
|
|
1414
|
+
This option would pad/narrow so that align_len * F == data_len for all but the last sub-sequences
|
|
1415
|
+
by setting it to {"data": ("classes", F)} to ensure concat_align_len == ceildiv(concat_data_len - P, F)
|
|
1409
1416
|
:param bool use_cache_manager:
|
|
1410
1417
|
:param dict[(int,int),dict] epoch_wise_filter: see :class:`EpochWiseFilter`
|
|
1411
1418
|
"""
|
|
@@ -1413,6 +1420,7 @@ class ConcatSeqsDataset(CachedDataset2):
|
|
|
1413
1420
|
self.seq_tag_delim = seq_tag_delim
|
|
1414
1421
|
self.remove_in_between_postfix = remove_in_between_postfix or {}
|
|
1415
1422
|
self.repeat_in_between_last_frame_up_to_multiple_of = repeat_in_between_last_frame_up_to_multiple_of or {}
|
|
1423
|
+
self.pad_narrow_data_to_multiple_of_target_len = pad_narrow_data_to_multiple_of_target_len or {}
|
|
1416
1424
|
self.epoch_wise_filter = EpochWiseFilter(epoch_wise_filter) if epoch_wise_filter else None
|
|
1417
1425
|
if isinstance(dataset, dict):
|
|
1418
1426
|
dataset = dataset.copy()
|
|
@@ -1486,7 +1494,7 @@ class ConcatSeqsDataset(CachedDataset2):
|
|
|
1486
1494
|
sub_seq_list.extend(sub_seq_tags)
|
|
1487
1495
|
assert sub_seq_idx == len(sub_seq_list) and len(seq_list) == len(sub_seq_idxs)
|
|
1488
1496
|
self.cur_sub_seq_idxs = sub_seq_idxs
|
|
1489
|
-
return self.sub_dataset.init_seq_order(seq_list=sub_seq_list)
|
|
1497
|
+
return self.sub_dataset.init_seq_order(epoch=epoch, seq_list=sub_seq_list)
|
|
1490
1498
|
|
|
1491
1499
|
def supports_seq_order_sorting(self) -> bool:
|
|
1492
1500
|
"""supports sorting"""
|
|
@@ -1539,6 +1547,11 @@ class ConcatSeqsDataset(CachedDataset2):
|
|
|
1539
1547
|
key,
|
|
1540
1548
|
sub_dataset_keys,
|
|
1541
1549
|
)
|
|
1550
|
+
for key in self.pad_narrow_data_to_multiple_of_target_len:
|
|
1551
|
+
assert key in sub_dataset_keys, (
|
|
1552
|
+
f"{self}: pad_narrow_data_to_multiple_of_target_len key {key}"
|
|
1553
|
+
f" not in sub dataset data-keys {sub_dataset_keys}"
|
|
1554
|
+
)
|
|
1542
1555
|
for sub_seq_idx, sub_seq_tag in zip(sub_seq_idxs, sub_seq_tags):
|
|
1543
1556
|
self.sub_dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1)
|
|
1544
1557
|
sub_dataset_tag = self.sub_dataset.get_tag(sub_seq_idx)
|
|
@@ -1562,6 +1575,17 @@ class ConcatSeqsDataset(CachedDataset2):
|
|
|
1562
1575
|
if data.shape[0] % multiple != 0:
|
|
1563
1576
|
data = numpy.concatenate([data] + [data[-1:]] * (multiple - data.shape[0] % multiple), axis=0)
|
|
1564
1577
|
assert data.shape[0] % multiple == 0
|
|
1578
|
+
if key in self.pad_narrow_data_to_multiple_of_target_len and sub_seq_idx != sub_seq_idxs[-1]:
|
|
1579
|
+
target_key, multiple = self.pad_narrow_data_to_multiple_of_target_len[key]
|
|
1580
|
+
target_data = self.sub_dataset.get_data(sub_seq_idx, target_key)
|
|
1581
|
+
len_diff = data.shape[0] - target_data.shape[0] * multiple
|
|
1582
|
+
if len_diff > 0:
|
|
1583
|
+
# if data is longer than target_data * multiple, narrow the data
|
|
1584
|
+
data = data[:-len_diff]
|
|
1585
|
+
elif len_diff < 0:
|
|
1586
|
+
# if data is shorter than target_data * multiple, pad by repeating the last frame
|
|
1587
|
+
data = numpy.concatenate([data] + [data[-1:]] * -len_diff, axis=0)
|
|
1588
|
+
assert data.shape[0] == target_data.shape[0] * multiple
|
|
1565
1589
|
features[key].append(data)
|
|
1566
1590
|
features = {key: numpy.concatenate(values, axis=0) for (key, values) in features.items()}
|
|
1567
1591
|
return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)
|
|
@@ -1606,6 +1630,10 @@ class ConcatSeqsDataset(CachedDataset2):
|
|
|
1606
1630
|
"""
|
|
1607
1631
|
return self.sub_dataset.get_data_shape(key)
|
|
1608
1632
|
|
|
1633
|
+
def get_total_num_seqs(self) -> int:
|
|
1634
|
+
"""total num seqs"""
|
|
1635
|
+
return len(self.full_seq_list)
|
|
1636
|
+
|
|
1609
1637
|
|
|
1610
1638
|
class ChunkShuffleDataset(CachedDataset2):
|
|
1611
1639
|
"""
|
|
@@ -0,0 +1,594 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Gradient checkpointing.
|
|
3
|
+
|
|
4
|
+
Following a lot of the code of the official
|
|
5
|
+
`torch.utils.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`__,
|
|
6
|
+
using ``torch.autograd.graph.saved_tensors_hooks``
|
|
7
|
+
and ``TorchDispatchMode``
|
|
8
|
+
but also handling the RNG fork and reset in a similar way.
|
|
9
|
+
|
|
10
|
+
See also :mod:`returnn.tf.util.gradient_checkpoint`:
|
|
11
|
+
same API and logic in TF, although it heavily makes use
|
|
12
|
+
of the TF computation graph, i.e. graph mode,
|
|
13
|
+
which makes this particular feature much easier to implement.
|
|
14
|
+
|
|
15
|
+
See also:
|
|
16
|
+
https://github.com/rwth-i6/returnn/issues/1552
|
|
17
|
+
https://discuss.pytorch.org/t/gradient-checkpointing/205416
|
|
18
|
+
https://gist.github.com/soulitzer/ec1049a947be046de7fbc2af61a4ee8c
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from typing import Optional, Union, Any, Callable, Sequence, List, Dict
|
|
24
|
+
from types import MethodType
|
|
25
|
+
from dataclasses import dataclass, field
|
|
26
|
+
import contextlib
|
|
27
|
+
from weakref import ref, WeakSet
|
|
28
|
+
import threading
|
|
29
|
+
|
|
30
|
+
import torch
|
|
31
|
+
from torch.utils.weak import WeakTensorKeyDictionary # needs Torch >=2.0.0
|
|
32
|
+
|
|
33
|
+
# noinspection PyProtectedMember
|
|
34
|
+
from torch.utils._python_dispatch import TorchDispatchMode
|
|
35
|
+
|
|
36
|
+
# PyTree is very common and semi-standard for PyTorch, e.g. __torch_dispatch__.
|
|
37
|
+
# We might use dm-tree or so alternatively here, but PyTree should be fine.
|
|
38
|
+
# noinspection PyProtectedMember
|
|
39
|
+
import torch.utils._pytree as pytree
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
__all__ = ["gradient_checkpoint_scope"]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# gradient_checkpoint_scope is the public API to the user.
|
|
46
|
+
# gradient_checkpoint_scope.__enter__ will enter two other scopes:
|
|
47
|
+
#
|
|
48
|
+
# - record_graph_scope: _RecordGraph(TorchDispatchMode),
|
|
49
|
+
# to record the computation graph for all ops within the scope.
|
|
50
|
+
#
|
|
51
|
+
# - saved_tensors_hooks_scope: torch.autograd.graph.saved_tensors_hooks,
|
|
52
|
+
# to overwrite what we store for backpropagation, and how to recompute it.
|
|
53
|
+
# Specifically, for all tensors which were created within the gradient_checkpoint_scope,
|
|
54
|
+
# we will never store them in the pack_hook,
|
|
55
|
+
# and unpack_hook will trigger the recomputation of the computation graph.
|
|
56
|
+
#
|
|
57
|
+
# gradient_checkpoint_scope.__exit__ will exit the record_graph_scope,
|
|
58
|
+
# but the saved_tensors_hooks_scope will stay alive as long as needed,
|
|
59
|
+
# while any of the created tensors are still alive.
|
|
60
|
+
# We keep a weak tensor key dictionary to map from the created raw tensors
|
|
61
|
+
# to the point in the recorded computation graph (specifically _GraphTensor objects).
|
|
62
|
+
# We just check whether any of the weak tensor refs is still alive.
|
|
63
|
+
#
|
|
64
|
+
# To keep saved_tensors_hooks_scope alive and make sure
|
|
65
|
+
# that other calls to torch.autograd.graph.saved_tensors_hooks are correctly handled,
|
|
66
|
+
# specifically that the order of enter/exit is correct,
|
|
67
|
+
# we hook into torch.autograd.graph.saved_tensors_hooks.__enter__/__exit__ itself.
|
|
68
|
+
# See _register_custom_saved_tensors_hooks below.
|
|
69
|
+
# Further, torch.autograd.graph.saved_tensors_hooks is thread local,
|
|
70
|
+
# so we can do any such logic only within the same thread.
|
|
71
|
+
# We also hook into Tensor.__del__ and also handle gradient_checkpoint_scope.__del__,
|
|
72
|
+
# but as that might run in a different thread, we cannot always do the cleanup there.
|
|
73
|
+
# We always check for this.
|
|
74
|
+
# (Note that this is due to the API of torch.autograd.graph.saved_tensors_hooks.
|
|
75
|
+
# We actually would want to always use it for a set of specified tensors.
|
|
76
|
+
# We also discuss some potentially better PyTorch API to implement this in an easier way:
|
|
77
|
+
# https://github.com/pytorch/pytorch/issues/129867)
|
|
78
|
+
#
|
|
79
|
+
# For the recomputation, we make sure that we properly reset the RNG and AMP states,
|
|
80
|
+
# and that we perform the recomputation in the exact same order, such that RNG state is correct.
|
|
81
|
+
#
|
|
82
|
+
# Once some recomputed tensor was used and is not needed anymore, the GC should free it.
|
|
83
|
+
# We try to make sure that no unnecessary references are kept alive.
|
|
84
|
+
#
|
|
85
|
+
# Also see test_gradient_checkpoint_scope() which tests this.
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class gradient_checkpoint_scope:
|
|
89
|
+
"""
|
|
90
|
+
Create a gradient checkpoint scope.
|
|
91
|
+
All tensors created within this scope will not be stored for backpropagation,
|
|
92
|
+
but will be recomputed on the fly during backpropagation.
|
|
93
|
+
|
|
94
|
+
Example::
|
|
95
|
+
|
|
96
|
+
a = ...
|
|
97
|
+
b = ...
|
|
98
|
+
c = ...
|
|
99
|
+
with gradient_checkpoint_scope():
|
|
100
|
+
x = a + b
|
|
101
|
+
y = x * c
|
|
102
|
+
|
|
103
|
+
In this example, the tensor ``x`` will not be stored for backpropagation,
|
|
104
|
+
i.e. the computation ``x = a + b`` will be recomputed during backpropagation.
|
|
105
|
+
|
|
106
|
+
Internally, this uses the PyTorch ``torch.autograd.graph.saved_tensors_hooks`` mechanism
|
|
107
|
+
to override what we store for backpropagation, and how to recompute it.
|
|
108
|
+
And we use the PyTorch ``TorchDispatchMode`` to intercept all operations within the scope.
|
|
109
|
+
Note that the usage of ``torch.autograd.graph.saved_tensors_hooks`` is tricky here
|
|
110
|
+
as we need it beyond the scope of the ``gradient_checkpoint_scope``,
|
|
111
|
+
specifically for all future usages of the tensor ``x`` in the example.
|
|
112
|
+
See the code documentation for more details on this.
|
|
113
|
+
|
|
114
|
+
Note, PyTorch itself also provides a gradient checkpointing API,
|
|
115
|
+
namely `torch.utils.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`__.
|
|
116
|
+
This API is different: You cannot easily specify what not to store / what to recompute.
|
|
117
|
+
Rather, you specify a start/end point for what to *store* for backpropagation,
|
|
118
|
+
and then PyTorch will recompute everything in between.
|
|
119
|
+
For the example above, you define that ``y`` is the end point and will be stored.
|
|
120
|
+
It looks like this::
|
|
121
|
+
|
|
122
|
+
a = ...
|
|
123
|
+
b = ...
|
|
124
|
+
c = ...
|
|
125
|
+
y = torch.utils.checkpoint.checkpoint(lambda: (a + b) * c)
|
|
126
|
+
|
|
127
|
+
PyTorch will not recompute ``... * c`` here,
|
|
128
|
+
but it will recompute ``a + b``.
|
|
129
|
+
We find this API more cumbersome to use and less flexible,
|
|
130
|
+
because in many cases, you know what you want to recompute, i.e. what you don't want to store.
|
|
131
|
+
The PyTorch API is more about what you want to store, and it then recomputes everything else in between.
|
|
132
|
+
|
|
133
|
+
See also:
|
|
134
|
+
https://github.com/rwth-i6/returnn/issues/1552
|
|
135
|
+
https://discuss.pytorch.org/t/gradient-checkpointing/205416
|
|
136
|
+
"""
|
|
137
|
+
|
|
138
|
+
def __init__(self):
|
|
139
|
+
self.record_graph_scope = _RecordGraph()
|
|
140
|
+
self.record_graph_scope.graph.gradient_checkpoint_scope_backref = self
|
|
141
|
+
# Note: saved_tensors_hooks is thread local.
|
|
142
|
+
self.saved_tensors_hooks_scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
|
|
143
|
+
self.entered = False
|
|
144
|
+
self.entered_thread_ref = None
|
|
145
|
+
self.exit_args: Optional[tuple] = None
|
|
146
|
+
self.exited_saved_tensors_hooks_scope = False
|
|
147
|
+
|
|
148
|
+
def __enter__(self):
|
|
149
|
+
self.record_graph_scope.__enter__()
|
|
150
|
+
self.saved_tensors_hooks_scope.__enter__()
|
|
151
|
+
self.entered = True
|
|
152
|
+
self.entered_thread_ref = ref(threading.current_thread())
|
|
153
|
+
|
|
154
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
155
|
+
self.exit_args = (exc_type, exc_val, exc_tb)
|
|
156
|
+
self.record_graph_scope.__exit__(exc_type, exc_val, exc_tb)
|
|
157
|
+
if self.record_graph_scope.graph.is_any_recorded_tensor_alive():
|
|
158
|
+
# Do not exit saved_tensors_hooks_scope here
|
|
159
|
+
# because we still want to pack any tensors which were captured in our graph
|
|
160
|
+
# by giving it a ref to the graph tensor.
|
|
161
|
+
# However, we must track any further external calls to saved_tensors_hooks_scope,
|
|
162
|
+
# to be able to properly remove it from the stack at the right point.
|
|
163
|
+
_register_custom_saved_tensors_hooks(existing_scope=self.saved_tensors_hooks_scope)
|
|
164
|
+
_register_custom_saved_tensors_hooks_thread_local_callback(
|
|
165
|
+
_WeakMethod(self._custom_saved_tensors_hooks_callback, return_if_dead=False)
|
|
166
|
+
)
|
|
167
|
+
else: # no relevant tensors alive anymore
|
|
168
|
+
self.exit_saved_tensors_hooks_scope()
|
|
169
|
+
|
|
170
|
+
def _maybe_exit_saved_tensors_hooks_scope(self):
|
|
171
|
+
if self.exited_saved_tensors_hooks_scope:
|
|
172
|
+
return
|
|
173
|
+
if not self.exit_args:
|
|
174
|
+
return
|
|
175
|
+
# If we are in the right thread, maybe we can do the cleanup now.
|
|
176
|
+
if self.entered_thread_ref() is threading.current_thread():
|
|
177
|
+
if not self.record_graph_scope.graph.is_any_recorded_tensor_alive():
|
|
178
|
+
self.exit_saved_tensors_hooks_scope()
|
|
179
|
+
|
|
180
|
+
def __del__(self):
|
|
181
|
+
# Note, be very careful what we do in __del__ because it might be called in a different thread!
|
|
182
|
+
# Note that the __del__ will likely be called very late,
|
|
183
|
+
# as the reference to the _Graph is kept alive until we used it for backprop,
|
|
184
|
+
# as we keep this alive via _Graph.gradient_checkpoint_scope_backref
|
|
185
|
+
# as long as any _GraphTensor is alive due to backprop pack_hook.
|
|
186
|
+
self._maybe_exit_saved_tensors_hooks_scope()
|
|
187
|
+
|
|
188
|
+
def exit_saved_tensors_hooks_scope(self):
|
|
189
|
+
"""
|
|
190
|
+
exit saved_tensors_hooks_scope if not yet done.
|
|
191
|
+
"""
|
|
192
|
+
assert self.entered_thread_ref() is threading.current_thread()
|
|
193
|
+
if self.exit_args and not self.exited_saved_tensors_hooks_scope:
|
|
194
|
+
# Note that via _register_custom_saved_tensors_hooks,
|
|
195
|
+
# this saved_tensors_hooks_scope.__exit__ might get to our _custom_saved_tensors_hooks_exit below,
|
|
196
|
+
# which will make sure that the order of __exit__ is correct.
|
|
197
|
+
self.exited_saved_tensors_hooks_scope = True
|
|
198
|
+
self.saved_tensors_hooks_scope.__exit__(*self.exit_args)
|
|
199
|
+
|
|
200
|
+
def _pack_hook(self, x: torch.Tensor) -> Union[torch.Tensor, _GraphTensor]:
|
|
201
|
+
if self.exit_args and not self.record_graph_scope.graph.is_any_recorded_tensor_alive():
|
|
202
|
+
# No raw tensors alive anymore in graph_tensor_from_raw_tensor,
|
|
203
|
+
# so we can exit saved_tensors_hooks_scope now.
|
|
204
|
+
# (We might not always catch this properly in the Tensor _DelHook,
|
|
205
|
+
# e.g. when Tensor.__del__ runs in a different thread.)
|
|
206
|
+
self.exit_saved_tensors_hooks_scope()
|
|
207
|
+
return x
|
|
208
|
+
# _RecordGraph.__torch_dispatch__ should have recorded all newly created tensors.
|
|
209
|
+
x_ = self.record_graph_scope.graph.graph_tensor_from_weak_raw_tensor.get(x, x)
|
|
210
|
+
if isinstance(x_, _GraphTensor):
|
|
211
|
+
x._RETURNN_grad_ckpt_del_hook = _DelHook(_WeakMethod(self._tensor_del_hook))
|
|
212
|
+
return x_
|
|
213
|
+
|
|
214
|
+
@staticmethod
|
|
215
|
+
def _unpack_hook(x: Union[torch.Tensor, _GraphTensor]) -> torch.Tensor:
|
|
216
|
+
if isinstance(x, _GraphTensor):
|
|
217
|
+
x.op.graph.gradient_checkpoint_scope_backref._maybe_exit_saved_tensors_hooks_scope()
|
|
218
|
+
x.op.graph.maybe_recompute()
|
|
219
|
+
return x.get_recomputed()
|
|
220
|
+
return x
|
|
221
|
+
|
|
222
|
+
def _tensor_del_hook(self):
|
|
223
|
+
# Some of the relevant tensors got deleted.
|
|
224
|
+
# If we are in the right thread, maybe we can do the cleanup now.
|
|
225
|
+
self._maybe_exit_saved_tensors_hooks_scope()
|
|
226
|
+
|
|
227
|
+
def _custom_saved_tensors_hooks_callback(self) -> bool:
|
|
228
|
+
assert self.entered_thread_ref() is threading.current_thread()
|
|
229
|
+
assert self.exit_args
|
|
230
|
+
if self.record_graph_scope.graph.is_any_recorded_tensor_alive():
|
|
231
|
+
return True # keep callback alive
|
|
232
|
+
else:
|
|
233
|
+
self.exit_saved_tensors_hooks_scope()
|
|
234
|
+
return False # we are done, can delete callback
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class _RecordGraph(TorchDispatchMode):
|
|
238
|
+
def __init__(self):
|
|
239
|
+
super().__init__()
|
|
240
|
+
self.graph = _Graph([])
|
|
241
|
+
|
|
242
|
+
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
243
|
+
kwargs = {} if kwargs is None else kwargs
|
|
244
|
+
graph = self.graph
|
|
245
|
+
graph.maybe_store_rng_state(torch.device("cpu"))
|
|
246
|
+
graph.maybe_store_amp_state(torch.device("cpu"))
|
|
247
|
+
pytree.tree_map(graph.maybe_store_rng_state, args)
|
|
248
|
+
pytree.tree_map(graph.maybe_store_rng_state, kwargs)
|
|
249
|
+
out = func(*args, **kwargs)
|
|
250
|
+
graph.record_op(func, args, kwargs, out)
|
|
251
|
+
return out
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@dataclass
|
|
255
|
+
class _Graph:
|
|
256
|
+
ops_to_be_recomputed: List[_GraphOp] = field(default_factory=list)
|
|
257
|
+
graph_tensor_from_weak_raw_tensor: WeakTensorKeyDictionary[torch.Tensor, _GraphTensor] = field(
|
|
258
|
+
default_factory=WeakTensorKeyDictionary
|
|
259
|
+
)
|
|
260
|
+
stored_device_rng_states: Dict[torch.device, Any] = field(default_factory=dict)
|
|
261
|
+
stored_device_amp_states: Dict[torch.device, Any] = field(default_factory=dict)
|
|
262
|
+
# Keep scope alive as long as any _GraphTensor is alive due to backprop pack_hook.
|
|
263
|
+
gradient_checkpoint_scope_backref: Optional[gradient_checkpoint_scope] = None
|
|
264
|
+
|
|
265
|
+
def is_any_recorded_tensor_alive(self) -> bool:
|
|
266
|
+
"""
|
|
267
|
+
:return: any recorded tensor is still alive.
|
|
268
|
+
Recorded tensors are outputs from any ops which were recorded,
|
|
269
|
+
i.e. ops under the gradient_checkpoint_scope.
|
|
270
|
+
"""
|
|
271
|
+
# graph_tensor_from_weak_raw_tensor is a WeakTensorKeyDictionary,
|
|
272
|
+
# i.e. once there is no other strong reference to some Tensor anymore,
|
|
273
|
+
# it would also be removed from graph_tensor_from_weak_raw_tensor.
|
|
274
|
+
return bool(self.graph_tensor_from_weak_raw_tensor)
|
|
275
|
+
|
|
276
|
+
def record_op(self, func: Any, args: Sequence[Any], kwargs: Dict[str, Any], out: Any):
|
|
277
|
+
"""record op"""
|
|
278
|
+
out_flat, _ = pytree.tree_flatten(out)
|
|
279
|
+
wrapped_args = pytree.tree_map_only(torch.Tensor, self.maybe_map_raw_tensor_to_graph_tensor, args)
|
|
280
|
+
wrapped_kwargs = pytree.tree_map_only(torch.Tensor, self.maybe_map_raw_tensor_to_graph_tensor, kwargs)
|
|
281
|
+
op = _GraphOp(
|
|
282
|
+
graph=self,
|
|
283
|
+
func=func,
|
|
284
|
+
args=wrapped_args,
|
|
285
|
+
kwargs=wrapped_kwargs,
|
|
286
|
+
out_flat_num=len(out_flat),
|
|
287
|
+
)
|
|
288
|
+
self.ops_to_be_recomputed.append(op)
|
|
289
|
+
for i, out_flat_elem in enumerate(out_flat):
|
|
290
|
+
if isinstance(out_flat_elem, torch.Tensor):
|
|
291
|
+
if out_flat_elem in self.graph_tensor_from_weak_raw_tensor:
|
|
292
|
+
continue
|
|
293
|
+
tensor_ = _GraphTensor(op=op, out_flat_idx=i)
|
|
294
|
+
self.graph_tensor_from_weak_raw_tensor[out_flat_elem] = tensor_
|
|
295
|
+
|
|
296
|
+
def maybe_store_rng_state(self, arg: Any):
    """
    Remember the RNG state of the device of ``arg``, unless a state for that device is already stored.

    The state is stored only once, at the first usage of the device,
    because it is also restored only once for the recomputation.
    The recomputation then relies on being performed in the correct order,
    which should be deterministic and thus produce the same RNG output.

    :param arg: a tensor (its device is used) or a device. Any other object is ignored.
    """
    if isinstance(arg, torch.device):
        dev = arg
    elif isinstance(arg, torch.Tensor):
        dev = arg.device
    else:
        return  # nothing device-related, nothing to store
    if dev in self.stored_device_rng_states:
        return  # first stored state wins
    self.stored_device_rng_states[dev] = _get_dev_rng_state(dev)
|
|
312
|
+
|
|
313
|
+
def maybe_store_amp_state(self, arg: Any):
    """
    Remember the AMP (autocast) state of the device of ``arg``,
    unless a state for that device is already stored.

    :param arg: a tensor (its device is used) or a device. Any other object is ignored.
    """
    if isinstance(arg, torch.device):
        dev = arg
    elif isinstance(arg, torch.Tensor):
        dev = arg.device
    else:
        return  # nothing device-related, nothing to store
    if dev in self.stored_device_amp_states:
        return  # first stored state wins
    self.stored_device_amp_states[dev] = _get_dev_amp_state(dev)
|
|
323
|
+
|
|
324
|
+
def maybe_map_raw_tensor_to_graph_tensor(self, tensor: torch.Tensor) -> Union[_GraphTensor, torch.Tensor]:
    """
    Look up the graph tensor which belongs to a raw tensor.

    :param tensor: raw tensor
    :return: the graph tensor registered for ``tensor`` if there is one, otherwise ``tensor`` itself.
    """
    mapping = self.graph_tensor_from_weak_raw_tensor
    return mapping.get(tensor, tensor)
|
|
327
|
+
|
|
328
|
+
def maybe_recompute(self):
    """
    Recompute all ops which are still pending, if there are any.

    The ops are recomputed in the order in which they were recorded,
    under the stored RNG and AMP states, so that any random number generator state comes out right.

    Design note: An API to recompute only a prefix of the ops (from op idx 0 up to a requested op idx,
    continuing from there on later calls) was considered.
    The problem with that is the RNG state:
    if other ops which use the RNG run in between two such partial recomputations,
    the RNG state would not be correct anymore for the continuation.
    Supporting that would require getting and resetting the RNG state again each time,
    which adds overhead.
    To keep it simple and cheap, all ops are recomputed together.

    What is done nevertheless: each op is dropped from the pending list before it is recomputed,
    so once a tensor referenced by it is not needed anymore, it can be garbage collected.
    """
    if not self.ops_to_be_recomputed:
        return
    with _reset_rng_states_scope(self.stored_device_rng_states), _reset_amp_states_scope(
        self.stored_device_amp_states
    ):
        # Reversed copy, so that taking from the end yields the ops in recording order.
        pending = list(reversed(self.ops_to_be_recomputed))
        self.ops_to_be_recomputed.clear()
        while pending:
            current = pending.pop()
            current.recompute()
    # The stored states were only needed for this one recomputation.
    self.stored_device_rng_states.clear()
    self.stored_device_amp_states.clear()
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
@dataclass
class _GraphOp:
    """
    One recorded op: the callable together with its (wrapped) arguments,
    and after :func:`recompute` also the flat list of recomputed outputs.
    """

    graph: _Graph
    func: Any
    # args/kwargs are set to None once the op was recomputed.
    args: Optional[Sequence[Union[_GraphTensor, Any]]]
    kwargs: Optional[Dict[str, Union[_GraphTensor, Any]]]
    out_flat_num: int
    # None as long as the op was not recomputed.
    recomputed_out_flat: Optional[Sequence[torch.Tensor]] = None

    def recompute(self):
        """
        Call the op again and store its flattened outputs.

        This assumes that all graph tensors among the args have been recomputed already.
        """
        resolve = _GraphTensor.get_recomputed
        call_args = pytree.tree_map_only(_GraphTensor, resolve, self.args)
        call_kwargs = pytree.tree_map_only(_GraphTensor, resolve, self.kwargs)
        result_flat, _ = pytree.tree_flatten(self.func(*call_args, **call_kwargs))
        assert len(result_flat) == self.out_flat_num
        self.recomputed_out_flat = result_flat
        # Drop the inputs: they are not needed anymore, and this potentially frees referenced resources.
        self.args = self.kwargs = None
        # self.func should be ok to keep, should ref some of the low-level aten functions
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
@dataclass
class _GraphTensor:
    """Reference to one flat output (by index) of a recorded op."""

    op: _GraphOp
    out_flat_idx: int

    def get_recomputed(self) -> torch.Tensor:
        """
        :return: the raw tensor for this graph tensor. This assumes that the op was recomputed already.
        """
        outputs = self.op.recomputed_out_flat
        assert outputs is not None
        return outputs[self.out_flat_idx]
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
@contextlib.contextmanager
def _reset_rng_states_scope(states: Dict[torch.device, Any]):
    """
    Context manager which sets the given RNG states on entry
    and puts back the previously active RNG states of the same devices on exit.

    Like torch.random.fork_rng but simpler.

    :param states: device -> RNG state to activate inside the scope
    """
    # Snapshot the current states first, for exactly those devices which get changed.
    saved = {}
    for device in states:
        saved[device] = _get_dev_rng_state(device)
    try:
        for device, wanted in states.items():
            _set_dev_rng_state(device, wanted)
        yield
    finally:
        for device, previous in saved.items():
            _set_dev_rng_state(device, previous)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def _get_dev_rng_state(dev: torch.device):
    """
    :param dev: device
    :return: current RNG state for the device. For CPU, this is ``torch.get_rng_state()``,
        otherwise ``get_rng_state(dev)`` of the torch submodule named like the device type.
    """
    if dev.type != "cpu":
        return getattr(torch, dev.type).get_rng_state(dev)
    return torch.get_rng_state()
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def _set_dev_rng_state(dev: torch.device, state: Any):
    """
    Counterpart of :func:`_get_dev_rng_state`.

    :param dev: device
    :param state: RNG state to activate. For CPU, it is passed to ``torch.set_rng_state``,
        otherwise to ``set_rng_state(state, dev)`` of the torch submodule named like the device type.
    """
    if dev.type == "cpu":
        torch.set_rng_state(state)
        return
    getattr(torch, dev.type).set_rng_state(state, dev)
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
@contextlib.contextmanager
def _reset_amp_states_scope(states: Dict[torch.device, Any]):
    """
    Context manager which enters an autocast scope for every device with a stored (non-empty) AMP state.

    :param states: device -> kwargs for the autocast of that device, or a falsy value (e.g. None)
        for a device which gets no autocast scope here.
    """
    with contextlib.ExitStack() as stack:
        for device, amp_kwargs in states.items():
            if amp_kwargs:
                amp_mod = torch.cpu.amp if device.type == "cpu" else getattr(torch, device.type).amp
                stack.enter_context(amp_mod.autocast(**amp_kwargs))
        yield
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def _get_dev_amp_state(dev: torch.device):
    """
    :param dev: device
    :return: dict with "dtype" and "cache_enabled" describing the currently active autocast for the device,
        or None if autocast is not enabled for the device,
        or if the device module does not provide ``is_autocast_enabled`` and ``get_autocast_dtype``.
    """
    kind = dev.type

    if kind == "cpu":
        if torch.is_autocast_cpu_enabled():
            return {
                "dtype": torch.get_autocast_cpu_dtype(),
                "cache_enabled": torch.is_autocast_cache_enabled(),
            }
        return None

    if kind == "cuda":
        if torch.is_autocast_enabled():
            return {
                "dtype": torch.get_autocast_gpu_dtype(),
                "cache_enabled": torch.is_autocast_cache_enabled(),
            }
        return None

    # Any other device type: use the autocast functions of the corresponding torch submodule, if it has them.
    backend = getattr(torch, kind)
    if not (hasattr(backend, "is_autocast_enabled") and hasattr(backend, "get_autocast_dtype")):
        return None
    if not backend.is_autocast_enabled():
        return None
    return {
        "dtype": backend.get_autocast_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
class _DelHook:
    """Calls the given callback when this object is deleted (``__del__``)."""

    def __init__(self, callback):
        """
        :param callback: called without arguments in ``__del__``
        """
        self.callback = callback

    def __del__(self):
        notify = self.callback
        notify()
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
class _WeakMethod:
    """
    Callable wrapper around a bound method which holds the instance only via a weak reference.
    Once the instance is gone, calling it returns ``return_if_dead`` instead of calling the method.
    """

    # wrong type hint because mypy/PyCharm don't handle MethodType well
    def __init__(self, method: Union[MethodType, Callable], *, return_if_dead: Any = None):
        """
        :param method: bound method (must be a MethodType)
        :param return_if_dead: what a call returns once the instance of the method is dead
        """
        assert isinstance(method, MethodType)
        self.obj = ref(method.__self__)
        self.func = method.__func__
        self.return_if_dead = return_if_dead

    def __call__(self, *args, **kwargs):
        target = self.obj()
        if target is not None:
            return self.func(target, *args, **kwargs)
        return self.return_if_dead
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
# The unpatched __enter__/__exit__ of saved_tensors_hooks, captured at import time.
# The custom replacements delegate to these, and they are put back when unregistering.
_orig_saved_tensors_hooks_enter = torch.autograd.graph.saved_tensors_hooks.__enter__
_orig_saved_tensors_hooks_exit = torch.autograd.graph.saved_tensors_hooks.__exit__
# Per-thread state of the custom hooks (attributes: stack, in_callback, callbacks, queued_exits, active).
_custom_saved_tensors_hooks_tls_ctx = threading.local()
_custom_saved_tensors_hooks_lock = threading.Lock()  # only needed for non thread-locals, i.e. threads, methods
# Threads for which the custom hooks are currently registered.
_custom_saved_tensors_hooks_registered_threads = WeakSet()
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def _register_custom_saved_tensors_hooks(*, existing_scope: torch.autograd.graph.saved_tensors_hooks):
    """
    Register our custom saved_tensors_hooks __enter__/__exit__ for the current thread.
    Does nothing if the current thread is registered already.

    Their purpose is to make sure that the order of __exit__ is correct,
    i.e. that the scopes are exited in the correct order.

    See :func:`_custom_saved_tensors_hooks_enter` and :func:`_custom_saved_tensors_hooks_exit`.

    There is no need to call :func:`_unregister_custom_saved_tensors_hooks` later.
    It will be called automatically when the last scope is exited.

    :param existing_scope: scope which is pushed onto the thread-local scope stack on registration
    """
    tls = _custom_saved_tensors_hooks_tls_ctx
    current = threading.current_thread()
    with _custom_saved_tensors_hooks_lock:
        if current in _custom_saved_tensors_hooks_registered_threads:
            return
        # First registration ever in this thread: create the thread-local containers.
        if getattr(tls, "stack", None) is None:
            tls.stack = []
            tls.in_callback = False
            tls.callbacks = []
            tls.queued_exits = []
        tls.active = True
        tls.stack.append(existing_scope)
        _custom_saved_tensors_hooks_registered_threads.add(current)
        # Only the first registered thread needs to patch the class.
        if len(_custom_saved_tensors_hooks_registered_threads) == 1:
            hooks_cls = torch.autograd.graph.saved_tensors_hooks
            hooks_cls.__enter__ = _custom_saved_tensors_hooks_enter
            hooks_cls.__exit__ = _custom_saved_tensors_hooks_exit
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def _unregister_custom_saved_tensors_hooks():
    """
    Unregister the custom saved_tensors_hooks __enter__/__exit__ for the current thread.

    The current thread must be registered,
    and its thread-local stack, callbacks and queued exits must all be empty.
    When no registered thread is left, the original __enter__/__exit__ are put back.
    """
    tls = _custom_saved_tensors_hooks_tls_ctx
    current = threading.current_thread()
    with _custom_saved_tensors_hooks_lock:
        assert current in _custom_saved_tensors_hooks_registered_threads
        assert not (tls.stack or tls.callbacks or tls.queued_exits)
        tls.active = False
        _custom_saved_tensors_hooks_registered_threads.remove(current)
        if not _custom_saved_tensors_hooks_registered_threads:
            hooks_cls = torch.autograd.graph.saved_tensors_hooks
            hooks_cls.__enter__ = _orig_saved_tensors_hooks_enter
            hooks_cls.__exit__ = _orig_saved_tensors_hooks_exit
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def _custom_saved_tensors_hooks_enter(self: torch.autograd.graph.saved_tensors_hooks):
    """
    Replacement for saved_tensors_hooks.__enter__:
    runs the thread-local callbacks, tracks the scope on the thread-local stack,
    and then delegates to the original __enter__.

    :return: whatever the original __enter__ returns
    """
    tls = _custom_saved_tensors_hooks_tls_ctx
    _custom_saved_tensors_hooks_call_callbacks()
    # The callbacks might have unregistered us. Only track the scope if we are still active.
    if tls.active:
        tls.stack.append(self)
    return _orig_saved_tensors_hooks_enter(self)
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def _custom_saved_tensors_hooks_exit(self: torch.autograd.graph.saved_tensors_hooks, exc_type, exc_val, exc_tb):
    """
    Replacement for saved_tensors_hooks.__exit__ which enforces the correct exit order.

    The exit of ``self`` is only queued at first.
    Then, as long as the scope on top of the thread-local stack has a queued exit,
    that scope is popped and really exited via the original __exit__.
    A scope which is not on top thus stays queued until all scopes above it have been exited.
    When the stack becomes empty, the custom hooks get unregistered (if still active).

    :raises Exception: if ``self`` is not on the thread-local stack
    """
    tls = _custom_saved_tensors_hooks_tls_ctx
    if self not in tls.stack:
        raise Exception(
            f"saved_tensors_hooks __exit__ mismatch."
            f" stack {tls.stack},"
            f" queued_exits {tls.queued_exits},"
            f" got self {self}"
        )
    tls.queued_exits.append(self)
    _custom_saved_tensors_hooks_call_callbacks()
    # If the top scope has no queued exit yet, we need to wait for it to exit first.
    # Once it exits, the others are exited as well when they get on top.
    while tls.stack and tls.stack[-1] in tls.queued_exits:
        top = tls.stack.pop(-1)
        tls.queued_exits.remove(top)
        _orig_saved_tensors_hooks_exit(top, exc_type, exc_val, exc_tb)
        exc_type, exc_val, exc_tb = None, None, None  # do not propagate this again (even though it's ignored anyway)
    if not tls.stack:
        assert not tls.queued_exits
        if tls.active:  # might have been unregistered in the meantime by callbacks
            _unregister_custom_saved_tensors_hooks()
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
def _register_custom_saved_tensors_hooks_thread_local_callback(cb: Callable[[], bool]):
    """
    Add a callback for the current thread
    which gets called on every saved_tensors_hooks __enter__ and __exit__.

    :param cb: called without arguments. As long as it returns True, it stays registered.
        Otherwise it is removed.
    """
    registered = _custom_saved_tensors_hooks_tls_ctx.callbacks
    registered.append(cb)
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
def _custom_saved_tensors_hooks_call_callbacks():
    """
    Call all thread-local callbacks, and keep only those which returned a true value.

    Does nothing when called from within one of the callbacks.
    """
    tls = _custom_saved_tensors_hooks_tls_ctx
    if tls.in_callback:
        return  # avoid recursive calls
    try:
        tls.in_callback = True
        survivors = []
        for cb in tls.callbacks:
            if cb():
                survivors.append(cb)
        tls.callbacks = survivors
    finally:
        tls.in_callback = False
|