returnn 1.20251023.135024__tar.gz → 1.20251118.160612__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- {returnn-1.20251023.135024/returnn.egg-info → returnn-1.20251118.160612}/PKG-INFO +1 -1
- returnn-1.20251118.160612/_setup_info_generated.py +2 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/config.py +1 -1
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/generating.py +3 -5
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/lm.py +20 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/meta.py +179 -60
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/util/vocabulary.py +90 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/attention.py +1 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/encoder/conformer.py +1 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/module.py +8 -1
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/engine.py +37 -3
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/basic.py +3 -6
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/better_exchook.py +4 -0
- returnn-1.20251118.160612/returnn/util/collect_outputs_dict.py +79 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/file_cache.py +15 -1
- {returnn-1.20251023.135024 → returnn-1.20251118.160612/returnn.egg-info}/PKG-INFO +1 -1
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn.egg-info/SOURCES.txt +1 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_Dataset.py +128 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_Util.py +19 -3
- returnn-1.20251023.135024/_setup_info_generated.py +0 -2
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/.editorconfig +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/.gitignore +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/.gitmodules +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/.kateconfig +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/CHANGELOG.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/CODEOWNERS +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/CONTRIBUTING.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/LICENSE +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/MANIFEST.in +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/README.rst +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/12AX.cluster_map +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/_setup_returnn_env.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-fwd.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-horovod-mpi.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-horovod-mpi.py.sh +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-horovod-mpi.sh +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-hyper-param-tuning.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-iter-dataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-list-devices.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-lua-torch-layer.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-pretrain.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-record-and-push-to-webserver.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-returnn-as-framework.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-rf-pt-benchmark.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-rf.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-rhn-enwik8.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-sprint-interface.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-att-copy.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-attention.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-enc-dec.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-hard-att-copy.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-lstm-benchmark.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-native-lstm.12ax.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-native-lstm2.12ax.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-neural-transducer.12ax.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-rec-explicit-lstm.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-rec-explicit-rnn.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-rec-self-att.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-search-compiled-graph.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-timit-lstm-ctc.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-torch.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/demo.sh +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/README.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/chars.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/config_demo +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/config_fwd +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/config_real +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/decode.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/go.sh +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/lines.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/split/eval.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/split/train.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/IAM/split/valid.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/artificial/create_test_h5.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/artificial/forwardconfig +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/artificial/go.sh +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/artificial/trainconfig +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/artificial_rgb/go.sh +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/pyproject.toml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/requirements.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/__main__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/__old_mod_loader__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/__setup__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/audio.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/basic.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/bundle_file.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/cached.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/cached2.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/distrib_files.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/hdf.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/huggingface.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/map.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/multi_proc.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/normalization_data.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/numpy_dump.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/postprocessing.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/raw_wav.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/sprint.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/stereo.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/text_dict.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/util/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/util/feature_extraction.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/datasets/util/strings.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/engine/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/engine/base.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/engine/batch.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/__main__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/.git +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/graph_editor/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/graph_editor/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/graph_editor/edit.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/graph_editor/reroute.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/graph_editor/select.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/graph_editor/subgraph.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/graph_editor/transform.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/extern/graph_editor/util.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/forward_iface.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_backend.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_cache.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_native/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_native/backend.cpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_native/backend.hpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_native/module.cpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_native/module.hpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_native/py_utils.hpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_native/tensor_ops.cpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_native/tensor_ops.hpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_numpy_backend.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_random_journal.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/_utils.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/array_.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/audio/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/audio/mel.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/audio/specaugment.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/backend.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/build_from_dict.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/cond.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/const.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/container.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/control_flow_ctx.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/conv.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/conversions/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/conversions/espnet_e_branchformer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/conversions/hf_llama.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/conversions/torch_nn.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/decoder/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/decoder/transformer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/device.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/dims.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/dropout.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/dtype.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/encoder/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/encoder/base.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/encoder/conformer_v2.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/encoder/e_branchformer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/encoder/transformer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/gradient.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/graph.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/hooks.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/init.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/label_smoothing.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/linear.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/loop.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/loss.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/math_.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/matmul.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/nested.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/normalization.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/parameter.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/parametrizations.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/parametrize.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/piecewise_linear.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/rand.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/rec.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/reduce.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/run_ctx.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/signal.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/state.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/stepwise_scheduler.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/tensor_array.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/frontend/types.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/import_/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/import_/common.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/import_/git.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/import_/import_.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/learning_rate_control.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/log.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/native_op.cpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/native_op.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/pretrain.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/sprint/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/sprint/cache.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/sprint/control.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/sprint/error_signals.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/sprint/extern_interface.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/sprint/interface.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/_dim_extra.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/_tensor_extra.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/_tensor_mixin_base.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/_tensor_op_overloads.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/control_flow_ctx.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/dim.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/marked_dim.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/tensor.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/tensor_dict.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tensor/utils.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/compat.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/data_pipeline.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/distributed.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/engine.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/_backend.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/_utils.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/cond.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/dims.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/layer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/loop.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/make_layer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/masked_computation.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/parameter_assign.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_low_level/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/frontend_low_level/_backend.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/horovod.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/hyper_param_tuning.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/layers/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/layers/base.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/layers/basic.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/layers/rec.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/layers/segmental_model.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/layers/signal_processing.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/layers/variable.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/native_op.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/network.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/sprint.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/updater.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/util/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/util/basic.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/util/data.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/util/gradient_checkpoint.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/util/ken_lm.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/tf/util/open_fst.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/data/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/data/extern_data.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/data/pipeline.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/data/queued_data_iter.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/data/tensor_utils.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/distributed.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/frontend/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/frontend/_backend.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/frontend/_rand.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/frontend/bridge.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/frontend/raw_ops.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/optim/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/optim/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/optim/lion.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/updater.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/util/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/util/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/util/array_.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/util/debug_inf_nan.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/util/diagnose_gpu.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/util/exception_helper.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/util/gradient_checkpoint.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/util/module.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/torch/util/scaled_gradient.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/__init__.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/bpe.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/debug.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/debug_helpers.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/fsa.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/literal_py_to_pickle.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/lru_cache.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/math.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/multi_proc_non_daemonic_spawn.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/native_code_compiler.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/pprint.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/py-to-pickle.cpp +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/py_ext_mod_compiler.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/result_with_reason.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/sig_proc.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/task_system.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/train_proc_manager.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn/util/watch_memory.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn.egg-info/dependency_links.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn.egg-info/requires.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/returnn.egg-info/top_level.txt +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/rnn.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/setup.cfg +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/setup.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/DummySprintExec.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm-inspection-profile.xml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/.gitignore +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/.name +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/misc.xml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/modules.xml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/returnn.iml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/_set_num_threads1.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/_setup_returnn_env.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/_setup_test_env.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/bpe-unicode-demo.codes +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/bpe-unicode-demo.vocab +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/lexicon_opt.fst +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/lexicon_opt.isyms +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/lexicon_opt.jpg +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/lexicon_opt.osyms +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/lint_common.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/pycharm-inspect.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/pylint.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/returnn-as-framework.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/rf_utils.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/spelling.dic +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_Config.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_Fsa.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_GeneratingDataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_HDFDataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_LearningRateControl.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_Log.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_MultiProcDataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_Pretrain.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_ResNet.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_SprintDataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_SprintInterface.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TFEngine.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TFNativeOp.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TFNetworkLayer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TFNetworkRecLayer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TFNetworkSigProcLayer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TFUpdater.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TFUtil.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TF_determinism.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TaskSystem.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TaskSystem_SharedMem.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_TranslationDataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_datasets_huggingface.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_demos.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_fork_exec.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_hdf_dump.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_array.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_attention.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_base.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_cond.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_const.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_container.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_conv.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_decoder_transformer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_encoder_conformer.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_gradient.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_label_smoothing.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_loop.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_math.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_normalization.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_piecewise_linear.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_rec.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_reduce.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_rf_signal.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_tensor.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_threading.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_tools.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_torch_dataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_torch_engine.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_torch_frontend.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_torch_internal_frontend.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/test_torch_util.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tests/torch_utils.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/_setup_returnn_env.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/analyze-dataset-batches.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/bliss-collect-seq-lens.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/bliss-dump-text.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/bliss-get-segment-names.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/bliss-to-ogg-zip.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/bpe-create-lexicon.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/calculate-word-error-rate.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/cleanup-old-models.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/collect-orth-symbols.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/collect-words.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/compile_native_op.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/compile_tf_graph.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/debug-dump-search-scores.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/debug-plot-search-scores.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/dump-dataset-raw-strings.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/dump-dataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/dump-forward-stats.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/dump-forward.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/dump-network-json.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/dump-pickle.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/extract_state_tying_from_dataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/file-cache.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/get-attention-weights.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/get-best-model-epoch.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/hdf_dump.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/hdf_dump_translation_dataset.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/import-blocks-mt-model.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/import-t2t-mt-model.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/.gitignore +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/Makefile +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/example/README.md +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/example/libs_list +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/example/state_vars_list +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/example/tensor_names_list +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/file.h +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/main.cc +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/rescorer.h +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/vocabulary.cc +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/lattice_rescorer/vocabulary.h +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/tf_avg_checkpoints.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/tf_inspect_checkpoint.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/tf_inspect_summary_log.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/torch_avg_checkpoints.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/torch_export_to_onnx.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/torch_inspect_checkpoint.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/torch_inspect_checkpoint_and_opt.py +0 -0
- {returnn-1.20251023.135024 → returnn-1.20251118.160612}/tools/torch_scale_tuning.py +0 -0
|
@@ -801,7 +801,7 @@ class SubProcCopyGlobalConfigPreInitFunc:
|
|
|
801
801
|
from returnn.log import log
|
|
802
802
|
from returnn import __old_mod_loader__
|
|
803
803
|
|
|
804
|
-
better_exchook.
|
|
804
|
+
better_exchook.setup_all()
|
|
805
805
|
__old_mod_loader__.disable_lazy_mod_loads()
|
|
806
806
|
|
|
807
807
|
if self.global_config:
|
|
@@ -1164,11 +1164,9 @@ class StaticDataset(CachedDataset2):
|
|
|
1164
1164
|
"""supports sorting"""
|
|
1165
1165
|
return True
|
|
1166
1166
|
|
|
1167
|
-
def _collect_single_seq(self, seq_idx):
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
:rtype: DatasetSeq
|
|
1171
|
-
"""
|
|
1167
|
+
def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
|
|
1168
|
+
if seq_idx >= len(self._seq_order):
|
|
1169
|
+
return None
|
|
1172
1170
|
corpus_seq_idx = self._seq_order[seq_idx]
|
|
1173
1171
|
data = self.data[corpus_seq_idx]
|
|
1174
1172
|
return DatasetSeq(
|
|
@@ -694,6 +694,26 @@ class LmDataset(CachedDataset2):
|
|
|
694
694
|
self.next_seq_idx = seq_idx + 1
|
|
695
695
|
return DatasetSeq(seq_idx=seq_idx, features=data, targets=targets, seq_tag=seq_tag)
|
|
696
696
|
|
|
697
|
+
def finish_epoch(self, *, free_resources: bool = False):
|
|
698
|
+
"""finish epoch"""
|
|
699
|
+
super().finish_epoch(free_resources=free_resources)
|
|
700
|
+
|
|
701
|
+
if free_resources:
|
|
702
|
+
self._orths_offsets_and_lens = None
|
|
703
|
+
if self._orth_mmaps is not None:
|
|
704
|
+
for m in self._orth_mmaps:
|
|
705
|
+
if m is not None:
|
|
706
|
+
m.close()
|
|
707
|
+
self._orth_mmaps = None
|
|
708
|
+
if self._orth_files is not None:
|
|
709
|
+
for f in self._orth_files:
|
|
710
|
+
if f is not None:
|
|
711
|
+
f.close()
|
|
712
|
+
self._orth_files = None
|
|
713
|
+
|
|
714
|
+
self._seq_list = None
|
|
715
|
+
self._seq_index_by_tag = None
|
|
716
|
+
|
|
697
717
|
|
|
698
718
|
def _is_bliss(filename):
|
|
699
719
|
"""
|
|
@@ -253,22 +253,12 @@ class MetaDataset(CachedDataset2):
|
|
|
253
253
|
}
|
|
254
254
|
|
|
255
255
|
self._seq_list_file = seq_list_file
|
|
256
|
-
self.seq_list_original =
|
|
257
|
-
self.
|
|
258
|
-
for key in self.dataset_keys:
|
|
259
|
-
assert len(self.seq_list_original[key]) == self.num_total_seqs
|
|
260
|
-
|
|
261
|
-
self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])}
|
|
256
|
+
self.seq_list_original: Optional[Dict[str, List[str]]] = None
|
|
257
|
+
self.tag_idx: Optional[Dict[str, int]] = None
|
|
262
258
|
|
|
263
259
|
self._seq_lens: Optional[Dict[str, NumbersDict]] = None
|
|
264
260
|
self._num_timesteps: Optional[NumbersDict] = None
|
|
265
261
|
self._seq_lens_file = seq_lens_file
|
|
266
|
-
if seq_lens_file:
|
|
267
|
-
seq_lens = load_json(filename=seq_lens_file)
|
|
268
|
-
assert isinstance(seq_lens, dict)
|
|
269
|
-
# dict[str,NumbersDict], seq-tag -> data-key -> len
|
|
270
|
-
self._seq_lens = {tag: NumbersDict(l) for (tag, l) in seq_lens.items()}
|
|
271
|
-
self._num_timesteps = sum([self._seq_lens[s] for s in self.seq_list_original[self.default_dataset_key]])
|
|
272
262
|
|
|
273
263
|
if data_dims:
|
|
274
264
|
data_dims = convert_data_dims(data_dims)
|
|
@@ -290,19 +280,20 @@ class MetaDataset(CachedDataset2):
|
|
|
290
280
|
self.num_outputs = self.data_dims
|
|
291
281
|
|
|
292
282
|
self.orig_seq_order_is_initialized = False
|
|
283
|
+
self._current_seq_order: List[int] = []
|
|
293
284
|
self.seq_list_ordered: Optional[Dict[str, List[str]]] = None
|
|
294
285
|
|
|
295
|
-
def
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
if not seq_list_file:
|
|
286
|
+
def _lazy_init_seq_list(self):
|
|
287
|
+
if self.seq_list_original is not None:
|
|
288
|
+
return
|
|
289
|
+
|
|
290
|
+
if not self._seq_list_file:
|
|
301
291
|
# We create a sequence list from all the sequences of the default dataset
|
|
302
292
|
# and hope that it also applies to the
|
|
303
293
|
# other datasets.
|
|
304
294
|
# This can only work if all datasets have the same tag format and the sequences in the other
|
|
305
295
|
# datasets are a subset of those in the default dataset.
|
|
296
|
+
# (But the order does not matter.)
|
|
306
297
|
default_dataset = self.datasets[self.default_dataset_key]
|
|
307
298
|
assert isinstance(default_dataset, Dataset)
|
|
308
299
|
print(
|
|
@@ -349,17 +340,18 @@ class MetaDataset(CachedDataset2):
|
|
|
349
340
|
break # only print one
|
|
350
341
|
del seq_list_set
|
|
351
342
|
raise Exception("Dataset %r is missing seqs." % key)
|
|
352
|
-
elif isinstance(
|
|
353
|
-
seq_list = Dataset._load_seq_list_file(
|
|
354
|
-
elif isinstance(
|
|
343
|
+
elif isinstance(self._seq_list_file, str):
|
|
344
|
+
seq_list = Dataset._load_seq_list_file(self._seq_list_file, expect_list=False)
|
|
345
|
+
elif isinstance(self._seq_list_file, dict):
|
|
355
346
|
for key in self.dataset_keys:
|
|
356
|
-
if key not in
|
|
347
|
+
if key not in self._seq_list_file:
|
|
357
348
|
raise ValueError(f"seq_list_file does not contain all datasets, missing {key}")
|
|
358
|
-
seq_list = {key: Dataset._load_seq_list_file(
|
|
349
|
+
seq_list = {key: Dataset._load_seq_list_file(self._seq_list_file[key]) for key in self.dataset_keys}
|
|
359
350
|
else:
|
|
360
|
-
raise TypeError(f"unexpected seq_list_file type {type(
|
|
351
|
+
raise TypeError(f"unexpected seq_list_file type {type(self._seq_list_file)}")
|
|
361
352
|
|
|
362
353
|
if isinstance(seq_list, list):
|
|
354
|
+
# Use same seq list for all datasets
|
|
363
355
|
seq_list = {key: seq_list for key in self.dataset_keys}
|
|
364
356
|
elif isinstance(seq_list, dict):
|
|
365
357
|
for key in self.dataset_keys:
|
|
@@ -368,10 +360,29 @@ class MetaDataset(CachedDataset2):
|
|
|
368
360
|
else:
|
|
369
361
|
raise TypeError(f"unexpected seq_list type {type(seq_list)}")
|
|
370
362
|
|
|
371
|
-
|
|
363
|
+
for key in self.dataset_keys:
|
|
364
|
+
assert len(seq_list[key]) == len(seq_list[self.default_dataset_key])
|
|
365
|
+
|
|
366
|
+
self.seq_list_original = seq_list
|
|
367
|
+
|
|
368
|
+
def _lazy_init_tag_idx(self):
|
|
369
|
+
if self.tag_idx is not None:
|
|
370
|
+
return
|
|
371
|
+
self._lazy_init_seq_list()
|
|
372
|
+
self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])}
|
|
373
|
+
|
|
374
|
+
def _lazy_init_seq_lens(self):
|
|
375
|
+
if self._seq_lens is not None:
|
|
376
|
+
return
|
|
377
|
+
assert self._seq_lens_file
|
|
378
|
+
seq_lens = load_json(filename=self._seq_lens_file)
|
|
379
|
+
assert isinstance(seq_lens, dict)
|
|
380
|
+
# dict[str,NumbersDict], seq-tag -> data-key -> len
|
|
381
|
+
self._seq_lens = {tag: NumbersDict(lens) for (tag, lens) in seq_lens.items()}
|
|
372
382
|
|
|
373
383
|
def _get_dataset_seq_length(self, seq_idx: int):
|
|
374
384
|
if not self.orig_seq_order_is_initialized:
|
|
385
|
+
self._lazy_init_seq_list()
|
|
375
386
|
# To use get_seq_length() we first have to init the sequence order once in original order.
|
|
376
387
|
# If sequence lengths are not needed by get_seq_order_for_epoch this is never executed.
|
|
377
388
|
self.datasets[self.default_dataset_key].init_seq_order(
|
|
@@ -379,6 +390,9 @@ class MetaDataset(CachedDataset2):
|
|
|
379
390
|
)
|
|
380
391
|
self.orig_seq_order_is_initialized = True
|
|
381
392
|
|
|
393
|
+
# Warning: This is not correct in the general case.
|
|
394
|
+
# get_seq_length needs to have load_seqs called beforehand per API contract.
|
|
395
|
+
# For some datasets, it might anyway work.
|
|
382
396
|
return self.datasets[self.default_dataset_key].get_seq_length(seq_idx)["data"]
|
|
383
397
|
|
|
384
398
|
def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
|
|
@@ -392,6 +406,7 @@ class MetaDataset(CachedDataset2):
|
|
|
392
406
|
self.epoch is None
|
|
393
407
|
or self.epoch != epoch
|
|
394
408
|
or self.seq_list_ordered is None
|
|
409
|
+
or not self._current_seq_order
|
|
395
410
|
or seq_list is not None
|
|
396
411
|
or seq_order is not None
|
|
397
412
|
or self.expected_load_seq_start > 0
|
|
@@ -401,16 +416,17 @@ class MetaDataset(CachedDataset2):
|
|
|
401
416
|
# This is called via initialize() with epoch=None, just to init some other things.
|
|
402
417
|
# We are not expected to have prepared any real epoch here.
|
|
403
418
|
self._num_seqs = 0
|
|
419
|
+
self._current_seq_order = []
|
|
404
420
|
return True
|
|
405
421
|
|
|
406
422
|
if not need_reinit:
|
|
407
|
-
self._num_seqs = len(self.seq_list_ordered[self.default_dataset_key])
|
|
408
423
|
return False
|
|
409
424
|
|
|
410
425
|
seq_order_dataset = None
|
|
411
426
|
if seq_order is not None:
|
|
412
427
|
seq_index = seq_order
|
|
413
428
|
elif seq_list is not None:
|
|
429
|
+
self._lazy_init_tag_idx()
|
|
414
430
|
seq_index = [self.tag_idx[tag] for tag in seq_list]
|
|
415
431
|
elif self.seq_order_control_dataset:
|
|
416
432
|
seq_order_dataset = self.datasets[self.seq_order_control_dataset]
|
|
@@ -418,13 +434,15 @@ class MetaDataset(CachedDataset2):
|
|
|
418
434
|
seq_order_dataset.init_seq_order(epoch=epoch)
|
|
419
435
|
seq_index = seq_order_dataset.get_current_seq_order()
|
|
420
436
|
else:
|
|
421
|
-
if self.
|
|
437
|
+
if self._seq_lens_file:
|
|
422
438
|
|
|
423
439
|
def get_seq_len(s):
|
|
424
440
|
"""
|
|
425
441
|
:param int s:
|
|
426
442
|
:rtype: int
|
|
427
443
|
"""
|
|
444
|
+
self._lazy_init_seq_list()
|
|
445
|
+
self._lazy_init_seq_lens()
|
|
428
446
|
return self._seq_lens[self.seq_list_original[self.default_dataset_key][s]]["data"]
|
|
429
447
|
|
|
430
448
|
elif self._seq_order_seq_lens_file:
|
|
@@ -432,8 +450,10 @@ class MetaDataset(CachedDataset2):
|
|
|
432
450
|
else:
|
|
433
451
|
self.orig_seq_order_is_initialized = False
|
|
434
452
|
get_seq_len = self._get_dataset_seq_length
|
|
435
|
-
seq_index = self.get_seq_order_for_epoch(epoch, self.
|
|
453
|
+
seq_index = self.get_seq_order_for_epoch(epoch, self.get_total_num_seqs(), get_seq_len)
|
|
436
454
|
self._num_seqs = len(seq_index)
|
|
455
|
+
self._current_seq_order = seq_index
|
|
456
|
+
self._lazy_init_seq_list()
|
|
437
457
|
self.seq_list_ordered = {key: [ls[s] for s in seq_index] for (key, ls) in self.seq_list_original.items()}
|
|
438
458
|
|
|
439
459
|
for dataset_key, dataset in self.datasets.items():
|
|
@@ -447,7 +467,7 @@ class MetaDataset(CachedDataset2):
|
|
|
447
467
|
"""supports sorting"""
|
|
448
468
|
if self.seq_order_control_dataset:
|
|
449
469
|
return self.datasets[self.seq_order_control_dataset].supports_seq_order_sorting()
|
|
450
|
-
if self.
|
|
470
|
+
if self._seq_lens_file or self._seq_order_seq_lens_file:
|
|
451
471
|
return True
|
|
452
472
|
return False
|
|
453
473
|
|
|
@@ -464,20 +484,40 @@ class MetaDataset(CachedDataset2):
|
|
|
464
484
|
:return: current seq order for the current epoch, after self.init_seq_order was called.
|
|
465
485
|
:rtype: list[int]
|
|
466
486
|
"""
|
|
467
|
-
return
|
|
487
|
+
return self._current_seq_order
|
|
468
488
|
|
|
469
489
|
def get_all_tags(self):
|
|
470
490
|
"""
|
|
471
491
|
:return: list of all seq tags, of the whole dataset, without partition epoch
|
|
472
492
|
:rtype: list[str]
|
|
473
493
|
"""
|
|
494
|
+
if self._seq_list_file is None:
|
|
495
|
+
return self.datasets[self.default_dataset_key].get_all_tags()
|
|
496
|
+
self._lazy_init_seq_list()
|
|
497
|
+
assert self.seq_list_original is not None
|
|
474
498
|
return self.seq_list_original[self.default_dataset_key]
|
|
475
499
|
|
|
476
500
|
def get_total_num_seqs(self, *, fast: bool = False) -> int:
|
|
477
501
|
"""
|
|
502
|
+
:param fast: if True, might raise an exception if not possible to get fast.
|
|
478
503
|
:return: total number of seqs, without partition epoch
|
|
479
504
|
"""
|
|
480
|
-
|
|
505
|
+
if self._seq_list_file is None:
|
|
506
|
+
return self.datasets[self.default_dataset_key].get_total_num_seqs(fast=fast)
|
|
507
|
+
if fast and self.seq_list_original is None:
|
|
508
|
+
raise OptionalNotImplementedError(f"{self} get_total_num_seqs, seq list not loaded yet")
|
|
509
|
+
self._lazy_init_seq_list()
|
|
510
|
+
assert self.seq_list_original is not None
|
|
511
|
+
return len(self.seq_list_original[self.default_dataset_key])
|
|
512
|
+
|
|
513
|
+
def get_num_timesteps(self):
|
|
514
|
+
"""num timesteps"""
|
|
515
|
+
if self._num_timesteps is None and self._seq_lens_file:
|
|
516
|
+
self._lazy_init_seq_lens()
|
|
517
|
+
self._num_timesteps = sum([self._seq_lens[s] for s in self.get_all_tags()], start=NumbersDict())
|
|
518
|
+
if self._seq_list_file is None:
|
|
519
|
+
return self.datasets[self.default_dataset_key].get_num_timesteps()
|
|
520
|
+
return super().get_num_timesteps()
|
|
481
521
|
|
|
482
522
|
def finish_epoch(self, *, free_resources: bool = False):
|
|
483
523
|
"""
|
|
@@ -503,8 +543,9 @@ class MetaDataset(CachedDataset2):
|
|
|
503
543
|
if start_ < end:
|
|
504
544
|
for dataset_key in self.dataset_keys:
|
|
505
545
|
self.datasets[dataset_key].load_seqs(start_, end)
|
|
506
|
-
|
|
507
|
-
|
|
546
|
+
if self.seq_list_ordered is not None:
|
|
547
|
+
for seq_idx in range(start_, end):
|
|
548
|
+
self._check_dataset_seq(dataset_key, seq_idx)
|
|
508
549
|
super(MetaDataset, self)._load_seqs(start=start, end=end)
|
|
509
550
|
|
|
510
551
|
def _check_dataset_seq(self, dataset_key, seq_idx):
|
|
@@ -531,7 +572,7 @@ class MetaDataset(CachedDataset2):
|
|
|
531
572
|
:type seq_idx: int
|
|
532
573
|
:rtype: DatasetSeq
|
|
533
574
|
"""
|
|
534
|
-
seq_tag = self.
|
|
575
|
+
seq_tag = self.get_tag(seq_idx)
|
|
535
576
|
features = {data_key: self._get_data(seq_idx, data_key) for data_key in self.data_keys}
|
|
536
577
|
return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)
|
|
537
578
|
|
|
@@ -540,8 +581,9 @@ class MetaDataset(CachedDataset2):
|
|
|
540
581
|
:param int sorted_seq_idx:
|
|
541
582
|
:rtype: NumbersDict
|
|
542
583
|
"""
|
|
543
|
-
if self.
|
|
544
|
-
|
|
584
|
+
if self._seq_lens_file:
|
|
585
|
+
self._lazy_init_seq_lens()
|
|
586
|
+
return self._seq_lens[self.get_tag(sorted_seq_idx)]
|
|
545
587
|
return super(MetaDataset, self).get_seq_length(sorted_seq_idx)
|
|
546
588
|
|
|
547
589
|
def get_tag(self, sorted_seq_idx):
|
|
@@ -549,7 +591,10 @@ class MetaDataset(CachedDataset2):
|
|
|
549
591
|
:param int sorted_seq_idx:
|
|
550
592
|
:rtype: str
|
|
551
593
|
"""
|
|
552
|
-
|
|
594
|
+
if self.seq_list_ordered is not None:
|
|
595
|
+
return self.seq_list_ordered[self.default_dataset_key][sorted_seq_idx]
|
|
596
|
+
else:
|
|
597
|
+
return self.datasets[self.default_dataset_key].get_tag(sorted_seq_idx)
|
|
553
598
|
|
|
554
599
|
def get_complete_frac(self, sorted_seq_idx: int, **kwargs) -> Optional[float]:
|
|
555
600
|
"""
|
|
@@ -961,10 +1006,10 @@ class CombinedDataset(CachedDataset2):
|
|
|
961
1006
|
super(CombinedDataset, self).__init__(**kwargs)
|
|
962
1007
|
assert self.shuffle_frames_of_nseqs == 0 # not implemented. anyway only for non-recurrent nets
|
|
963
1008
|
|
|
1009
|
+
self.data_map = data_map
|
|
964
1010
|
self.dataset_keys = set([m[0] for m in data_map.keys()]) # type: typing.Set[str]
|
|
965
1011
|
self.dataset_idx2key_map = dict(enumerate(sorted(self.dataset_keys))) # idx -> dataset-key
|
|
966
1012
|
self.data_keys = set(data_map.values()) # type: typing.Set[str]
|
|
967
|
-
assert "data" in self.data_keys
|
|
968
1013
|
self.target_list = sorted(self.data_keys - {"data"})
|
|
969
1014
|
|
|
970
1015
|
# Build target lookup table that maps from dataset_key and data_key (data key used by CombinedDataset)
|
|
@@ -994,8 +1039,7 @@ class CombinedDataset(CachedDataset2):
|
|
|
994
1039
|
if data_dims:
|
|
995
1040
|
data_dims = convert_data_dims(data_dims)
|
|
996
1041
|
self.data_dims = data_dims
|
|
997
|
-
|
|
998
|
-
for key in self.target_list:
|
|
1042
|
+
for key in self.data_keys:
|
|
999
1043
|
assert key in data_dims
|
|
1000
1044
|
else:
|
|
1001
1045
|
self.data_dims = {}
|
|
@@ -1009,7 +1053,7 @@ class CombinedDataset(CachedDataset2):
|
|
|
1009
1053
|
if dataset_data_key in dataset.labels:
|
|
1010
1054
|
self.labels[data_key] = dataset.labels[dataset_data_key]
|
|
1011
1055
|
|
|
1012
|
-
self.num_inputs = self.data_dims["data"][0]
|
|
1056
|
+
self.num_inputs = self.data_dims["data"][0] if "data" in self.data_dims else 0
|
|
1013
1057
|
self.num_outputs = self.data_dims
|
|
1014
1058
|
|
|
1015
1059
|
self.data_dtypes = {
|
|
@@ -1019,6 +1063,9 @@ class CombinedDataset(CachedDataset2):
|
|
|
1019
1063
|
|
|
1020
1064
|
self.dataset_seq_idx_boundaries: Optional[List[int]] = None
|
|
1021
1065
|
self.dataset_sorted_seq_idx_list: Optional[List[Tuple[int, int]]] = None
|
|
1066
|
+
self._sub_dataset_cur_loaded_seq_range: Optional[List[Tuple[int, int]]] = None
|
|
1067
|
+
# The usage is about the seqs already covered in dataset_sorted_seq_idx_list,
|
|
1068
|
+
# in case we dynamically build up this list.
|
|
1022
1069
|
self.used_num_seqs_per_subset: Optional[List[int]] = None
|
|
1023
1070
|
|
|
1024
1071
|
def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
|
|
@@ -1030,7 +1077,7 @@ class CombinedDataset(CachedDataset2):
|
|
|
1030
1077
|
"""
|
|
1031
1078
|
|
|
1032
1079
|
assert seq_list is None and seq_order is None, "seq_list and seq_order not supported for %s" % self.__class__
|
|
1033
|
-
need_reinit = self.epoch is None or self.epoch != epoch
|
|
1080
|
+
need_reinit = self.epoch is None or self.epoch != epoch or self.expected_load_seq_start > 0
|
|
1034
1081
|
num_seqs_saved = self._num_seqs
|
|
1035
1082
|
super(CombinedDataset, self).init_seq_order(
|
|
1036
1083
|
epoch=epoch, seq_list=seq_list, seq_order=seq_order
|
|
@@ -1047,13 +1094,15 @@ class CombinedDataset(CachedDataset2):
|
|
|
1047
1094
|
for dataset in self.datasets.values():
|
|
1048
1095
|
dataset.init_seq_order(epoch=epoch)
|
|
1049
1096
|
|
|
1097
|
+
self._sub_dataset_cur_loaded_seq_range = [(0, 0)] * len(self.datasets)
|
|
1098
|
+
|
|
1050
1099
|
# noinspection PyBroadException
|
|
1051
1100
|
try:
|
|
1052
1101
|
total_num_seqs = sum([self.datasets[k].num_seqs for k in sorted(self.datasets.keys())])
|
|
1053
1102
|
except Exception:
|
|
1054
1103
|
total_num_seqs = None
|
|
1055
1104
|
|
|
1056
|
-
if total_num_seqs is not None:
|
|
1105
|
+
if total_num_seqs is not None and self.seq_ordering != "interleave":
|
|
1057
1106
|
self.dataset_seq_idx_boundaries = self._create_dataset_seq_idx_boundaries()
|
|
1058
1107
|
|
|
1059
1108
|
if self.sampling_sizes:
|
|
@@ -1090,7 +1139,7 @@ class CombinedDataset(CachedDataset2):
|
|
|
1090
1139
|
|
|
1091
1140
|
# Re-initialize sequence orders of sub-datasets with created sequence list.
|
|
1092
1141
|
self.used_num_seqs_per_subset = []
|
|
1093
|
-
for dataset_idx, dataset_key in self.dataset_idx2key_map.items():
|
|
1142
|
+
for dataset_idx, dataset_key in sorted(self.dataset_idx2key_map.items()):
|
|
1094
1143
|
assert self.datasets[dataset_key].have_corpus_seq_idx()
|
|
1095
1144
|
self.datasets[dataset_key].init_seq_order(epoch=epoch, seq_order=seq_order_subdatasets[dataset_idx])
|
|
1096
1145
|
self.used_num_seqs_per_subset.append(len(seq_order_subdatasets[dataset_idx]))
|
|
@@ -1098,6 +1147,11 @@ class CombinedDataset(CachedDataset2):
|
|
|
1098
1147
|
else:
|
|
1099
1148
|
self.dataset_sorted_seq_idx_list = [] # We will fill this as we go
|
|
1100
1149
|
self.used_num_seqs_per_subset = [0] * len(self.datasets)
|
|
1150
|
+
self._num_seqs = total_num_seqs
|
|
1151
|
+
|
|
1152
|
+
# These are currently not supported/implemented.
|
|
1153
|
+
# All of these should just be done in the sub-datasets directly.
|
|
1154
|
+
assert self.partition_epoch == 1 and self.repeat_epoch == 1 and self._num_shards == 1
|
|
1101
1155
|
|
|
1102
1156
|
return True
|
|
1103
1157
|
|
|
@@ -1236,13 +1290,34 @@ class CombinedDataset(CachedDataset2):
|
|
|
1236
1290
|
|
|
1237
1291
|
return dataset.get_estimated_seq_length(dataset_seq_idx)
|
|
1238
1292
|
|
|
1239
|
-
def
|
|
1293
|
+
def _sub_dataset_make_cur_loaded(self, dataset_idx: int) -> bool:
|
|
1294
|
+
# Cur meaning for the next sequence to be added to dataset_sorted_seq_idx_list.
|
|
1295
|
+
seq_idx = self.used_num_seqs_per_subset[dataset_idx]
|
|
1296
|
+
cur_start, cur_end = self._sub_dataset_cur_loaded_seq_range[dataset_idx]
|
|
1297
|
+
|
|
1298
|
+
if not self.datasets[self.dataset_idx2key_map[dataset_idx]].is_less_than_num_seqs(seq_idx):
|
|
1299
|
+
return False
|
|
1300
|
+
|
|
1301
|
+
if seq_idx >= cur_end:
|
|
1302
|
+
self._sub_dataset_load_seqs(dataset_idx, cur_start, seq_idx + 1)
|
|
1303
|
+
return True
|
|
1304
|
+
elif seq_idx < cur_start:
|
|
1305
|
+
return False
|
|
1306
|
+
else:
|
|
1307
|
+
return True
|
|
1308
|
+
|
|
1309
|
+
def _expand_dataset_seq_idxs(self, num_values: int) -> bool:
|
|
1240
1310
|
"""
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1311
|
+
Try to extend dataset_sorted_seq_idx_list.
|
|
1312
|
+
We expect that we have reached the end of it.
|
|
1313
|
+
|
|
1314
|
+
:param num_values: Add num_values entries to the dataset-segment-idx mapping table
|
|
1315
|
+
:return: whether we added num_values entries
|
|
1244
1316
|
"""
|
|
1245
|
-
for
|
|
1317
|
+
for _ in range(num_values):
|
|
1318
|
+
for j in range(len(self.datasets)):
|
|
1319
|
+
self._sub_dataset_make_cur_loaded(j)
|
|
1320
|
+
|
|
1246
1321
|
if self.seq_ordering == "default": # i.e. in order
|
|
1247
1322
|
dataset_idx = 0
|
|
1248
1323
|
while dataset_idx < len(self.datasets):
|
|
@@ -1265,6 +1340,32 @@ class CombinedDataset(CachedDataset2):
|
|
|
1265
1340
|
else:
|
|
1266
1341
|
return False # No dataset has remaining data
|
|
1267
1342
|
|
|
1343
|
+
elif self.seq_ordering == "interleave":
|
|
1344
|
+
complete_fracs_and_ds_idx = [
|
|
1345
|
+
(
|
|
1346
|
+
self.datasets[self.dataset_idx2key_map[j]].get_complete_frac(
|
|
1347
|
+
self.used_num_seqs_per_subset[j], allow_only_lr_suitable=True
|
|
1348
|
+
)
|
|
1349
|
+
if self.datasets[self.dataset_idx2key_map[j]].is_less_than_num_seqs(
|
|
1350
|
+
self.used_num_seqs_per_subset[j]
|
|
1351
|
+
)
|
|
1352
|
+
else float("inf"),
|
|
1353
|
+
j,
|
|
1354
|
+
)
|
|
1355
|
+
for j in range(len(self.datasets))
|
|
1356
|
+
]
|
|
1357
|
+
assert all(frac is not None for frac, _ in complete_fracs_and_ds_idx), (
|
|
1358
|
+
f"{self}: Datasets must provide complete frac for interleave,"
|
|
1359
|
+
f" got {complete_fracs_and_ds_idx}, dataset idx2key map {self.dataset_idx2key_map}"
|
|
1360
|
+
)
|
|
1361
|
+
# Sort by complete frac, i.e. datasets with the lowest complete frac first.
|
|
1362
|
+
complete_fracs_and_ds_idx.sort()
|
|
1363
|
+
for complete_frac, dataset_idx in complete_fracs_and_ds_idx:
|
|
1364
|
+
if complete_frac < float("inf"):
|
|
1365
|
+
break
|
|
1366
|
+
else:
|
|
1367
|
+
return False # No dataset has remaining data
|
|
1368
|
+
|
|
1268
1369
|
elif self.seq_ordering == "random_dataset":
|
|
1269
1370
|
while True:
|
|
1270
1371
|
# Build probability table
|
|
@@ -1323,19 +1424,23 @@ class CombinedDataset(CachedDataset2):
|
|
|
1323
1424
|
def _load_seqs(self, start, end):
    """
    Load the seqs [start, end) by loading the covering seq range in each involved sub-dataset,
    then delegate to the base class ``_load_seqs``.

    :param int start: first seq idx (inclusive), indexing dataset_sorted_seq_idx_list
    :param int end: end seq idx (exclusive)
    """
    # The seq order might only be fixed up to some point. Try to extend it as far as requested.
    num_known_seqs = len(self.dataset_sorted_seq_idx_list)
    if end > num_known_seqs:
        self._expand_dataset_seq_idxs(end - num_known_seqs)

    # Entries are indexed as (dataset_idx, dataset_seq_idx).
    requested = self.dataset_sorted_seq_idx_list[start:end]

    for dataset_idx in range(len(self.datasets)):
        sub_seq_idxs = [entry[1] for entry in requested if entry[0] == dataset_idx]
        if sub_seq_idxs:
            # Load one contiguous range which covers all requested seqs of this sub-dataset.
            self._sub_dataset_load_seqs(dataset_idx, min(sub_seq_idxs), max(sub_seq_idxs) + 1)
    super(CombinedDataset, self)._load_seqs(start=start, end=end)
|
|
1338
1438
|
|
|
1439
|
+
def _sub_dataset_load_seqs(self, dataset_idx: int, start: int, end: int):
|
|
1440
|
+
self._sub_dataset_cur_loaded_seq_range[dataset_idx] = (start, end)
|
|
1441
|
+
dataset = self.datasets[self.dataset_idx2key_map[dataset_idx]]
|
|
1442
|
+
dataset.load_seqs(start, end)
|
|
1443
|
+
|
|
1339
1444
|
def _get_data(self, dataset_key, dataset_seq_idx, data_key):
|
|
1340
1445
|
"""
|
|
1341
1446
|
:type dataset_seq_idx: int
|
|
@@ -1348,7 +1453,10 @@ class CombinedDataset(CachedDataset2):
|
|
|
1348
1453
|
if dataset_data_key is not None:
|
|
1349
1454
|
return dataset.get_data(dataset_seq_idx, dataset_data_key)
|
|
1350
1455
|
else:
|
|
1351
|
-
|
|
1456
|
+
shape: List[int] = [0] * self.num_outputs[data_key][1]
|
|
1457
|
+
if shape and not self.is_data_sparse(data_key):
|
|
1458
|
+
shape[-1] = self.get_data_dim(data_key)
|
|
1459
|
+
return numpy.zeros(shape, dtype=self.data_dtypes[data_key])
|
|
1352
1460
|
|
|
1353
1461
|
def _collect_single_seq(self, seq_idx):
|
|
1354
1462
|
"""
|
|
@@ -1362,19 +1470,30 @@ class CombinedDataset(CachedDataset2):
|
|
|
1362
1470
|
dataset = self.datasets[dataset_key]
|
|
1363
1471
|
|
|
1364
1472
|
seq_tag = dataset.get_tag(dataset_seq_idx)
|
|
1365
|
-
features = self._get_data(dataset_key, dataset_seq_idx,
|
|
1366
|
-
|
|
1367
|
-
|
|
1473
|
+
features = {key: self._get_data(dataset_key, dataset_seq_idx, key) for key in self.data_keys}
|
|
1474
|
+
complete_frac = None
|
|
1475
|
+
if self.seq_ordering == "interleave":
|
|
1476
|
+
# In the interleave case, by design, this should be monotonically increasing,
|
|
1477
|
+
# as per how we select the next seq in _expand_dataset_seq_idxs.
|
|
1478
|
+
complete_frac = dataset.get_complete_frac(dataset_seq_idx, allow_only_lr_suitable=True)
|
|
1479
|
+
# In other cases, complete_frac is not so straightforward.
|
|
1480
|
+
# In the case that the total num seqs is known, then it's anyway not necessary.
|
|
1481
|
+
return DatasetSeq(seq_idx=seq_idx, complete_frac=complete_frac, seq_tag=seq_tag, features=features)
|
|
1368
1482
|
|
|
1369
|
-
def is_less_than_num_seqs(self, n):
|
|
1483
|
+
def is_less_than_num_seqs(self, n: int) -> bool:
    """
    :param n: seq idx
    :return: True if n is within the already known seq order (dataset_sorted_seq_idx_list).
        Otherwise, the result of trying to extend the seq order far enough to cover n.
    """
    num_known_seqs = len(self.dataset_sorted_seq_idx_list)
    if n >= num_known_seqs:
        # Need to add entries up to and including index n.
        return self._expand_dataset_seq_idxs(n - num_known_seqs + 1)
    return True
|
|
1491
|
+
|
|
1492
|
+
def get_data_keys(self) -> List[str]:
    """
    :return: all data keys in sorted order, except that "data" (if present) always comes first
    """
    keys = self.data_keys
    if "data" not in keys:
        return sorted(keys)
    other_keys = sorted(keys - {"data"})
    return ["data", *other_keys]
|
|
1378
1497
|
|
|
1379
1498
|
def get_target_list(self):
|
|
1380
1499
|
"""
|
|
@@ -11,6 +11,7 @@ __all__ = [
|
|
|
11
11
|
"SentencePieces",
|
|
12
12
|
"CharacterTargets",
|
|
13
13
|
"Utf8ByteTargets",
|
|
14
|
+
"HuggingFaceTokenizer",
|
|
14
15
|
]
|
|
15
16
|
|
|
16
17
|
from typing import Optional, Union, Type, Callable, List, Dict
|
|
@@ -691,3 +692,92 @@ class Utf8ByteTargets(Vocabulary):
|
|
|
691
692
|
assert ((seq >= 0) & (seq < 256)).all(), f"invalid byte value, must be within 0-255: {seq}"
|
|
692
693
|
seq = seq.astype(numpy.uint8)
|
|
693
694
|
return bytearray(seq).decode(encoding="utf8")
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
class HuggingFaceTokenizer(Vocabulary):
    """
    Uses the `AutoTokenizer` class from the `transformers` package.
    """

    def __init__(self, *, huggingface_repo_dir: str):
        """
        :param str huggingface_repo_dir: the directory containing the `tokenizer_config.json` file.
        """
        import transformers  # noqa

        # Make sure it is a string. (Could be e.g. Sis Path.)
        huggingface_repo_dir = str(huggingface_repo_dir)
        self._opts = {"huggingface_repo_dir": huggingface_repo_dir}
        self._cache_key = huggingface_repo_dir
        # NOTE(review): trust_remote_code=True allows the tokenizer repo to run custom Python code on load.
        # Only point this to repo dirs from trusted sources.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(huggingface_repo_dir, trust_remote_code=True)
        super().__init__(
            vocab_file=None,
            seq_postfix=None,
            unknown_label=self.tokenizer.unk_token_id,
            eos_label=self.tokenizer.eos_token_id,
            bos_label=self.tokenizer.bos_token_id,
            pad_label=self.tokenizer.pad_token_id,
        )

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._opts)

    def _parse_vocab(self):
        self.num_labels = len(self.tokenizer)
        # Do not load labels/vocab here. This is not really needed.

    @property
    def labels(self) -> List[str]:
        """list of labels, lazily built on first access, and shared via the cache (keyed by the repo dir)"""
        if self._cache_key and self._cache_key in self._cache:
            self._vocab, self._labels = self._cache[self._cache_key]
            assert self.num_labels == len(self._vocab) == len(self._labels)
        else:
            # Use the public API. In contrast to the private per-id lookup of the tokenizer,
            # this also resolves tokens which were added on top of the base vocab.
            self._labels = self.tokenizer.convert_ids_to_tokens(list(range(self.num_labels)))
            self._vocab = {label: i for (i, label) in enumerate(self._labels)}
            if self._cache_key:
                self._cache[self._cache_key] = (self._vocab, self._labels)
        return self._labels

    def is_id_valid(self, idx: int) -> bool:
        """
        :param idx:
        :return: whether idx is within [0, len(tokenizer))
        """
        return 0 <= idx < len(self.tokenizer)

    def id_to_label(self, idx: int, default: Union[str, Type[KeyError], None] = KeyError) -> Optional[str]:
        """
        :param idx:
        :param default: returned if idx is not a valid id. If this is KeyError (the default), KeyError is raised.
        :return: token string for idx, or default
        """
        if not self.is_id_valid(idx):
            if default is KeyError:
                raise KeyError("idx %r out of range" % (idx,))
            return default
        return self.tokenizer.convert_ids_to_tokens(idx)

    def label_to_id(self, label: str, default: Union[int, Type[KeyError], None] = KeyError) -> Optional[int]:
        """
        :param label:
        :param default: returned if label is not found. If this is KeyError (the default), KeyError is raised.
        :return: id of label, or default
        """
        res = self.tokenizer.convert_tokens_to_ids(label)
        unknown_id = self.unknown_label_id
        # Check for None first, as the comparison `res < 0` is not valid for None.
        if res is None or res == unknown_id or res < 0:
            # It could be that the label really is the unknown-label, or it could be that the label is unknown.
            if unknown_id is not None and label == self.id_to_label(unknown_id, default=None):
                return unknown_id
            if default is KeyError:
                raise KeyError("label %r not found" % label)
            return default
        return res

    def get_seq(self, sentence: str) -> List[int]:
        """
        :param sentence: raw text. It is passed as-is to the tokenizer.
        :return: the ``input_ids`` as returned by the tokenizer
        """
        return self.tokenizer(sentence)["input_ids"]

    def get_seq_labels(self, seq):
        """
        :param list[int]|numpy.ndarray seq: 1D sequence
        :return: decoded text, with special tokens skipped
        :rtype: str
        """
        return self.tokenizer.decode(seq, skip_special_tokens=True)
|