spikezoo 0.2.2__tar.gz → 0.2.3.2__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo-0.2.3.2/PKG-INFO +263 -0
- spikezoo-0.2.3.2/spikezoo/__init__.py +29 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/models/bsf/rep.py +2 -2
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/nets.py +1 -1
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/models/networks.py +1 -1
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssml/model.py +9 -5
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/metrics/losses.py +1 -1
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/models/networks_STIR.py +16 -9
- {spikezoo-0.2.2/spikezoo/archs/spikeformer/Metrics → spikezoo-0.2.3.2/spikezoo/archs/stir/package_core/build/lib/package_core}/__init__.py +0 -0
- {spikezoo-0.2.2/spikezoo/archs/spikeformer/Model → spikezoo-0.2.3.2/spikezoo/archs/stir/package_core/package_core}/__init__.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/tfi/nets.py +1 -1
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/tfp/nets.py +1 -1
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/dwtnets.py +6 -6
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/datasets/__init__.py +11 -9
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/datasets/base_dataset.py +10 -3
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/datasets/realworld_dataset.py +1 -3
- spikezoo-0.2.2/spikezoo/datasets/reds_small_dataset.py → spikezoo-0.2.3.2/spikezoo/datasets/reds_base_dataset.py +9 -8
- spikezoo-0.2.3.2/spikezoo/datasets/reds_ssir_dataset.py +181 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/datasets/szdata_dataset.py +5 -15
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/datasets/uhsr_dataset.py +4 -3
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/models/__init__.py +8 -6
- spikezoo-0.2.3.2/spikezoo/models/base_model.py +231 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/models/bsf_model.py +11 -3
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/models/spikeclip_model.py +4 -3
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/models/spk2imgnet_model.py +9 -15
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/models/ssir_model.py +4 -6
- spikezoo-0.2.3.2/spikezoo/models/ssml_model.py +60 -0
- spikezoo-0.2.3.2/spikezoo/models/stir_model.py +58 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/models/tfi_model.py +3 -1
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/models/tfp_model.py +4 -2
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/models/wgse_model.py +8 -14
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/pipeline/base_pipeline.py +79 -55
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/pipeline/ensemble_pipeline.py +10 -9
- spikezoo-0.2.3.2/spikezoo/pipeline/train_cfgs.py +89 -0
- spikezoo-0.2.3.2/spikezoo/pipeline/train_pipeline.py +193 -0
- spikezoo-0.2.3.2/spikezoo/utils/optimizer_utils.py +22 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/utils/other_utils.py +31 -6
- spikezoo-0.2.3.2/spikezoo/utils/scheduler_utils.py +25 -0
- spikezoo-0.2.3.2/spikezoo/utils/spike_utils.py +118 -0
- spikezoo-0.2.3.2/spikezoo.egg-info/PKG-INFO +263 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo.egg-info/SOURCES.txt +5 -47
- spikezoo-0.2.2/MANIFEST.in +0 -5
- spikezoo-0.2.2/PKG-INFO +0 -196
- spikezoo-0.2.2/README.md +0 -159
- spikezoo-0.2.2/requirements.txt +0 -18
- spikezoo-0.2.2/setup.py +0 -23
- spikezoo-0.2.2/spikezoo/__init__.py +0 -13
- spikezoo-0.2.2/spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
- spikezoo-0.2.2/spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
- spikezoo-0.2.2/spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
- spikezoo-0.2.2/spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
- spikezoo-0.2.2/spikezoo/archs/spikeformer/EvalResults/readme +0 -1
- spikezoo-0.2.2/spikezoo/archs/spikeformer/LICENSE +0 -21
- spikezoo-0.2.2/spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
- spikezoo-0.2.2/spikezoo/archs/spikeformer/Model/Loss.py +0 -89
- spikezoo-0.2.2/spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
- spikezoo-0.2.2/spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/spikeformer/README.md +0 -30
- spikezoo-0.2.2/spikezoo/archs/spikeformer/evaluate.py +0 -87
- spikezoo-0.2.2/spikezoo/archs/spikeformer/recon_real_data.py +0 -97
- spikezoo-0.2.2/spikezoo/archs/spikeformer/requirements.yml +0 -95
- spikezoo-0.2.2/spikezoo/archs/spikeformer/train.py +0 -173
- spikezoo-0.2.2/spikezoo/archs/spikeformer/utils.py +0 -22
- spikezoo-0.2.2/spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo-0.2.2/spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo-0.2.2/spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo-0.2.2/spikezoo/models/base_model.py +0 -175
- spikezoo-0.2.2/spikezoo/models/spikeformer_model.py +0 -50
- spikezoo-0.2.2/spikezoo/models/ssml_model.py +0 -18
- spikezoo-0.2.2/spikezoo/models/stir_model.py +0 -37
- spikezoo-0.2.2/spikezoo/pipeline/train_pipeline.py +0 -94
- spikezoo-0.2.2/spikezoo/utils/spike_utils.py +0 -86
- spikezoo-0.2.2/spikezoo.egg-info/PKG-INFO +0 -196
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/LICENSE.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/setup.cfg +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/__init__.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/base/nets.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/README.md +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/datasets/datasets.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/datasets/ds_utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/main.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/metrics/psnr.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/metrics/ssim.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/models/bsf/align.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/models/bsf/dsft_convert.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/models/get_model.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/prepare_data/DSFT.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/prepare_data/crop_train.sh +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/prepare_data/crop_val.sh +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/prepare_data/io_utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/requirements.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/test.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/bsf/utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spikeclip/nets.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/.gitignore +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/DCNv2.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/align_arch.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/dataset.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/readme.md +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/test_gen_imgseq.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/train.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/spk2imgnet/utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/README.md +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/configs/SSIR.yml +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/configs/yml_parser.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/datasets/dataset_sreds.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/datasets/ds_utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/losses.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/main.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/metrics/psnr.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/metrics/ssim.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/models/Vgg19.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/models/layers.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/requirements.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/shells/eval_SREDS.sh +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/shells/train_SSIR.sh +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/test.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssir/utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssml/cbam.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssml/res.png +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/ssml/test.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/.git-credentials +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/README.md +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/configs/STIR.yml +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/configs/utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/configs/yml_parser.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/datasets/dataset_sreds.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/datasets/ds_utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/eval_SREDS.sh +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/main.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/metrics/psnr.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/metrics/ssim.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/models/Vgg19.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/models/submodules.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/models/transformer_new.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/convertions.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/disp_netS.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/flow_utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/generic_train_test.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/geometry.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/image_proc.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/linalg.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/losses.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/metrics.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/model_base.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/net_basics.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/resnet.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/transforms.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core/utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/package_core/setup.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/requirements.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/stir/train_STIR.sh +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/README.md +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/dataset.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/demo.png +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/demo.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/eval.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/submodules.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/train.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/transform.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/archs/wgse/weights/demo.png +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/metrics/__init__.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/pipeline/__init__.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/utils/__init__.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/utils/data_utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/utils/img_utils.py +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo.egg-info/dependency_links.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo.egg-info/requires.txt +0 -0
- {spikezoo-0.2.2 → spikezoo-0.2.3.2}/spikezoo.egg-info/top_level.txt +0 -0
@@ -0,0 +1,263 @@
|
|
1
|
+
Metadata-Version: 2.2
|
2
|
+
Name: spikezoo
|
3
|
+
Version: 0.2.3.2
|
4
|
+
Summary: A deep learning toolbox for spike-to-image models.
|
5
|
+
Home-page: https://github.com/chenkang455/Spike-Zoo
|
6
|
+
Author: Kang Chen
|
7
|
+
Author-email: mrchenkang@stu.pku.edu.cn
|
8
|
+
Requires-Python: >=3.7
|
9
|
+
Description-Content-Type: text/markdown
|
10
|
+
License-File: LICENSE.txt
|
11
|
+
Requires-Dist: torch
|
12
|
+
Requires-Dist: requests
|
13
|
+
Requires-Dist: numpy
|
14
|
+
Requires-Dist: tqdm
|
15
|
+
Requires-Dist: scikit-image
|
16
|
+
Requires-Dist: lpips
|
17
|
+
Requires-Dist: pyiqa
|
18
|
+
Requires-Dist: opencv-python
|
19
|
+
Requires-Dist: thop
|
20
|
+
Requires-Dist: pytorch-wavelets
|
21
|
+
Requires-Dist: pytz
|
22
|
+
Requires-Dist: PyWavelets
|
23
|
+
Requires-Dist: pandas
|
24
|
+
Requires-Dist: pillow
|
25
|
+
Requires-Dist: scikit-learn
|
26
|
+
Requires-Dist: scipy
|
27
|
+
Requires-Dist: spikingjelly
|
28
|
+
Requires-Dist: setuptools
|
29
|
+
Dynamic: author
|
30
|
+
Dynamic: author-email
|
31
|
+
Dynamic: description
|
32
|
+
Dynamic: description-content-type
|
33
|
+
Dynamic: home-page
|
34
|
+
Dynamic: requires-dist
|
35
|
+
Dynamic: requires-python
|
36
|
+
Dynamic: summary
|
37
|
+
|
38
|
+
<p align="center">
|
39
|
+
<br>
|
40
|
+
<img src="imgs/spike-zoo.png" width="500"/>
|
41
|
+
<br>
|
42
|
+
</p>
|
43
|
+
|
44
|
+
<h5 align="center">
|
45
|
+
|
46
|
+
[](https://github.com/chenkang455/Spike-Zoo/stargazers) [](https://github.com/chenkang455/Spike-Zoo/issues) <a href="https://badge.fury.io/py/spikezoo"><img src="https://badge.fury.io/py/spikezoo.svg" alt="PyPI version"></a> [](https://github.com/chenkang455/Spike-Zoo)
|
47
|
+
|
48
|
+
</h5>
|
49
|
+
|
50
|
+
<!-- <h2 align="center">
|
51
|
+
<a href="">⚡Spike-Zoo: A Toolbox for Spike-to-Image Reconstruction
|
52
|
+
</a>
|
53
|
+
</h2> -->
|
54
|
+
|
55
|
+
## 📖 About
|
56
|
+
⚡Spike-Zoo is the go-to library for state-of-the-art pretrained **spike-to-image** models designed to reconstruct images from spike streams. Whether you're looking for a simple inference solution or aiming to train your own spike-to-image models, ⚡Spike-Zoo is a modular toolbox that supports both, with key features including:
|
57
|
+
|
58
|
+
- Fast inference with pre-trained models.
|
59
|
+
- Training support for custom-designed spike-to-image models.
|
60
|
+
- Specialized functions for processing spike data.
|
61
|
+
|
62
|
+
|
63
|
+
|
64
|
+
## 🚩 Updates/Changelog
|
65
|
+
* **25-02-02:** Release the `Spike-Zoo v0.2` code, which supports more methods and provides more usages, such as training your own method from scratch.
|
66
|
+
* **24-07-19:** Release the `Spike-Zoo v0.1` code for base evaluation of SOTA methods.
|
67
|
+
|
68
|
+
## 🍾 Quick Start
|
69
|
+
### 1. Installation
|
70
|
+
For users focused on **utilizing pretrained models for spike-to-image conversion**, we recommend installing SpikeZoo using one of the following methods:
|
71
|
+
|
72
|
+
* Install the latest stable version `0.2.3` from PyPI:
|
73
|
+
```
|
74
|
+
pip install spikezoo
|
75
|
+
```
|
76
|
+
* Install the latest development version `0.2.3` from the source code:
|
77
|
+
```
|
78
|
+
git clone https://github.com/chenkang455/Spike-Zoo
|
79
|
+
cd Spike-Zoo
|
80
|
+
python setup.py install
|
81
|
+
```
|
82
|
+
|
83
|
+
For users interested in **training their own spike-to-image model based on our framework**, we recommend cloning the repository and modifying the related code directly.
|
84
|
+
```
|
85
|
+
git clone https://github.com/chenkang455/Spike-Zoo
|
86
|
+
cd Spike-Zoo
|
87
|
+
python setup.py develop
|
88
|
+
```
|
89
|
+
|
90
|
+
### 2. Inference
|
91
|
+
Reconstructing images from spike streams is easy with Spike-Zoo. Try the following code for a single model:
|
92
|
+
``` python
|
93
|
+
from spikezoo.pipeline import Pipeline, PipelineConfig
|
94
|
+
import spikezoo as sz
|
95
|
+
pipeline = Pipeline(
|
96
|
+
cfg=PipelineConfig(save_folder="results",version="v023"),
|
97
|
+
model_cfg=sz.METHOD.BASE,
|
98
|
+
dataset_cfg=sz.DATASET.BASE
|
99
|
+
)
|
100
|
+
```
|
101
|
+
You can also run multiple models at once by switching to the ensemble pipeline (the `version` parameter selects one of the versions we have published in [Releases](https://github.com/chenkang455/Spike-Zoo/releases)):
|
102
|
+
``` python
|
103
|
+
import spikezoo as sz
|
104
|
+
from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig
|
105
|
+
pipeline = EnsemblePipeline(
|
106
|
+
cfg=EnsemblePipelineConfig(save_folder="results",version="v023"),
|
107
|
+
model_cfg_list=[
|
108
|
+
sz.METHOD.BASE,sz.METHOD.TFP,sz.METHOD.TFI,sz.METHOD.SPK2IMGNET,sz.METHOD.WGSE,
|
109
|
+
sz.METHOD.SSML,sz.METHOD.BSF,sz.METHOD.STIR,sz.METHOD.SPIKECLIP,sz.METHOD.SSIR],
|
110
|
+
dataset_cfg=sz.DATASET.BASE,
|
111
|
+
)
|
112
|
+
```
|
113
|
+
Once the pipeline is established, the following functions are available for using these spike-to-image models.
|
114
|
+
|
115
|
+
* I. Obtain the restoration metric and save the recovered image from the given spike:
|
116
|
+
``` python
|
117
|
+
# 1. spike-to-image from the given dataset
|
118
|
+
pipeline.infer_from_dataset(idx = 0)
|
119
|
+
|
120
|
+
# 2. spike-to-image from the given .dat file
|
121
|
+
pipeline.infer_from_file(file_path = 'data/scissor.dat',width = 400,height=250)
|
122
|
+
|
123
|
+
# 3. spike-to-image from the given spike
|
124
|
+
import spikezoo as sz
|
125
|
+
spike = sz.load_vidar_dat("data/scissor.dat",width = 400,height = 250)
|
126
|
+
pipeline.infer_from_spk(spike)
|
127
|
+
```
|
128
|
+
|
129
|
+
|
130
|
+
* II. Save all images from the given dataset.
|
131
|
+
``` python
|
132
|
+
pipeline.save_imgs_from_dataset()
|
133
|
+
```
|
134
|
+
|
135
|
+
* III. Calculate the metrics for the specified dataset.
|
136
|
+
``` python
|
137
|
+
pipeline.cal_metrics()
|
138
|
+
```
|
139
|
+
|
140
|
+
* IV. Calculate the parameters (params,flops,latency) based on the established pipeline.
|
141
|
+
``` python
|
142
|
+
pipeline.cal_params()
|
143
|
+
```
|
144
|
+
|
145
|
+
For detailed usage, see [test_single.ipynb](examples/test/test_single.ipynb) and [test_ensemble.ipynb](examples/test/test_ensemble.ipynb).
|
146
|
+
|
147
|
+
### 3. Training
|
148
|
+
We provide user-friendly code for training our `base` model (modified from `SpikeCLIP`) on the classic `REDS` dataset introduced in `Spk2ImgNet`:
|
149
|
+
``` python
|
150
|
+
from spikezoo.pipeline import TrainPipelineConfig, TrainPipeline
|
151
|
+
from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig
|
152
|
+
from spikezoo.models.base_model import BaseModelConfig
|
153
|
+
pipeline = TrainPipeline(
|
154
|
+
cfg=TrainPipelineConfig(save_folder="results", epochs = 10),
|
155
|
+
dataset_cfg=REDS_BASEConfig(root_dir = "spikezoo/data/REDS_BASE"),
|
156
|
+
model_cfg=BaseModelConfig(),
|
157
|
+
)
|
158
|
+
pipeline.train()
|
159
|
+
```
|
160
|
+
Training finishes in `2 minutes` on a single 4090 GPU, achieving `32.8dB` in PSNR and `0.92` in SSIM.
|
161
|
+
|
162
|
+
> 🌟 We encourage users to develop their models with simple modifications to our framework, and the tutorial will be released soon.
|
163
|
+
|
164
|
+
We retrain all supported methods except `SPIKECLIP` on this REDS dataset (the training scripts are in [examples/train_reds_base](examples/train_reds_base) and the evaluation script is [test_REDS_base.py](examples/test/test_REDS_base.py)). Our reported metrics are as follows:
|
165
|
+
|
166
|
+
| Method | PSNR | SSIM | LPIPS | NIQE | BRISQUE | PIQE | Params (M) | FLOPs (G) | Latency (ms) |
|
167
|
+
|----------------------|:-------:|:--------:|:---------:|:---------:|:----------:|:-------:|:------------:|:-----------:|:--------------:|
|
168
|
+
| `TFI` | 16.503 | 0.454 | 0.382 | 7.289 | 43.17 | 49.12 | 0.00 | 0.00 | 3.60 |
|
169
|
+
| `TFP` | 24.287 | 0.644 | 0.274 | 8.197 | 48.48 | 38.38 | 0.00 | 0.00 | 0.03 |
|
170
|
+
| `SPIKECLIP` | 21.873 | 0.578 | 0.333 | 7.802 | 42.08 | 54.01 | 0.19 | 23.69 | 1.27 |
|
171
|
+
| `SSIR` | 26.544 | 0.718 | 0.325 | 4.769 | 28.45 | 21.59 | 0.38 | 25.92 | 4.52 |
|
172
|
+
| `SSML` | 33.697 | 0.943 | 0.088 | 4.669 | 32.48 | 37.30 | 2.38 | 386.02 | 244.18 |
|
173
|
+
| `BASE` | 36.589 | 0.965 | 0.034 | 4.393 | 26.16 | 38.43 | 0.18 | 18.04 | 0.40 |
|
174
|
+
| `STIR` | 37.914 | 0.973 | 0.027 | 4.236 | 25.10 | 39.18 | 5.08 | 43.31 | 21.07 |
|
175
|
+
| `WGSE` | 39.036 | 0.978 | 0.023 | 4.231 | 25.76 | 44.11 | 3.81 | 415.26 | 73.62 |
|
176
|
+
| `SPK2IMGNET` | 39.154 | 0.978 | 0.022 | 4.243 | 25.20 | 43.09 | 3.90 | 1000.50 | 123.38 |
|
177
|
+
| `BSF` | 39.576 | 0.979 | 0.019 | 4.139 | 24.93 | 43.03 | 2.47 | 705.23 | 401.50 |
|
178
|
+
|
179
|
+
### 4. Model Usage
|
180
|
+
We also provide a direct interface for users who want to use a spike-to-image model as part of their own work:
|
181
|
+
|
182
|
+
```python
|
183
|
+
import spikezoo as sz
|
184
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
185
|
+
# input data
|
186
|
+
spike = sz.load_vidar_dat("data/data.dat", width=400, height=250, out_format="tensor")
|
187
|
+
spike = spike[None].cuda()
|
188
|
+
print(f"Input spike shape: {spike.shape}")
|
189
|
+
# net
|
190
|
+
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}))
|
191
|
+
net.build_network(mode = "debug")
|
192
|
+
# process
|
193
|
+
recon_img = net(spike)
|
194
|
+
print(recon_img.shape,recon_img.max(),recon_img.min())
|
195
|
+
```
|
196
|
+
For detailed usage, see [test_model.ipynb](examples/test/test_model.ipynb).
|
197
|
+
|
198
|
+
### 5. Spike Utility
|
199
|
+
#### I. Faster spike loading interface
|
200
|
+
We provide a faster `load_vidar_dat` function implemented with `cpp` (by [@zeal-ye](https://github.com/zeal-ye)):
|
201
|
+
``` python
|
202
|
+
import spikezoo as sz
|
203
|
+
spike = sz.load_vidar_dat("data/scissor.dat",width = 400,height = 250,version='cpp')
|
204
|
+
```
|
205
|
+
🚀 Results on [test_load_dat.py](examples/test_load_dat.py) show that the `cpp` version is more than 10 times faster than the `python` version.
|
206
|
+
|
207
|
+
#### II. Spike simulation pipeline.
|
208
|
+
We provide our complete spike simulation pipeline in [scripts](scripts/). Modify the config in `run.sh`, then run the following command to start the simulation:
|
209
|
+
``` bash
|
210
|
+
bash run.sh
|
211
|
+
```
|
212
|
+
|
213
|
+
#### III. Spike-related functions.
|
214
|
+
For other spike-related functions, see [spike_utils.py](spikezoo/utils/spike_utils.py).
|
215
|
+
|
216
|
+
## 📅 TODO
|
217
|
+
- [x] Support the overall pipeline for spike simulation.
|
218
|
+
- [ ] Provide the tutorials.
|
219
|
+
- [ ] Support more training settings.
|
220
|
+
- [ ] Support more spike-based image reconstruction methods and datasets.
|
221
|
+
|
222
|
+
## 🤗 Supports
|
223
|
+
Run the following code to find our supported models, datasets and metrics:
|
224
|
+
``` python
|
225
|
+
import spikezoo as sz
|
226
|
+
print(sz.METHODS)
|
227
|
+
print(sz.DATASETS)
|
228
|
+
print(sz.METRICS)
|
229
|
+
```
|
230
|
+
**Supported Models:**
|
231
|
+
| Models | Source |
|
232
|
+
| ---- | ---- |
|
233
|
+
| `tfp`,`tfi` | Spike camera and its coding methods |
|
234
|
+
| `spk2imgnet` | Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream |
|
235
|
+
| `wgse` | Learning Temporal-Ordered Representation for Spike Streams Based on Discrete Wavelet Transforms |
|
236
|
+
| `ssml` | Self-Supervised Mutual Learning for Dynamic Scene Reconstruction of Spiking Camera |
|
237
|
+
| `ssir` | Spike Camera Image Reconstruction Using Deep Spiking Neural Networks |
|
238
|
+
| `bsf` | Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations |
|
239
|
+
| `stir` | Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras |
|
240
|
+
| `base`,`spikeclip` | Rethinking High-speed Image Reconstruction Framework with Spike Camera |
|
241
|
+
|
242
|
+
**Supported Datasets:**
|
243
|
+
| Datasets | Source |
|
244
|
+
| ---- | ---- |
|
245
|
+
| `reds_base` | Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream |
|
246
|
+
| `uhsr` | Recognizing Ultra-High-Speed Moving Objects with Bio-Inspired Spike Camera |
|
247
|
+
| `realworld` | `recVidarReal2019`,`momVidarReal2021` in [SpikeCV](https://github.com/Zyj061/SpikeCV) |
|
248
|
+
| `szdata` | SpikeReveal: Unlocking Temporal Sequences from Real Blurry Inputs with Spike Streams |
|
249
|
+
|
250
|
+
|
251
|
+
## ✨ Acknowledgment
|
252
|
+
Our code is built on the open-source projects [SpikeCV](https://spikecv.github.io/), [IQA-Pytorch](https://github.com/chaofengc/IQA-PyTorch), [BasicSR](https://github.com/XPixelGroup/BasicSR) and [NeRFStudio](https://github.com/nerfstudio-project/nerfstudio). We appreciate the efforts of the contributors to these repositories. Thanks to [@ruizhao26](https://github.com/ruizhao26), [@shiyan_chen](https://github.com/hnmizuho) and [@Leozhangjiyuan](https://github.com/Leozhangjiyuan) for their help in building this project.
|
253
|
+
|
254
|
+
## 📑 Citation
|
255
|
+
If you find our code helpful for your research, please consider using the following citation:
|
256
|
+
```
|
257
|
+
@misc{spikezoo,
|
258
|
+
title={{Spike-Zoo}: Spike-Zoo: A Toolbox for Spike-to-Image Reconstruction},
|
259
|
+
author={Kang Chen and Zhiyuan Ye},
|
260
|
+
year={2025},
|
261
|
+
howpublished = "[Online]. Available: \url{https://github.com/chenkang455/Spike-Zoo}"
|
262
|
+
}
|
263
|
+
```
|
@@ -0,0 +1,29 @@
|
|
1
|
+
from .utils.spike_utils import *
|
2
|
+
from .models import model_list
|
3
|
+
from .datasets import dataset_list
|
4
|
+
from .metrics import metric_all_names
|
5
|
+
|
6
|
+
# METHOD NAME DEFINITION
|
7
|
+
METHODS = model_list
|
8
|
+
class METHOD:
|
9
|
+
BASE = "base"
|
10
|
+
TFP = "tfp"
|
11
|
+
TFI = "tfi"
|
12
|
+
SPK2IMGNET = "spk2imgnet"
|
13
|
+
WGSE = "wgse"
|
14
|
+
SSML = "ssml"
|
15
|
+
BSF = "bsf"
|
16
|
+
STIR = "stir"
|
17
|
+
SSIR = "ssir"
|
18
|
+
SPIKECLIP = "spikeclip"
|
19
|
+
|
20
|
+
# DATASET NAME DEFINITION
|
21
|
+
DATASETS = dataset_list
|
22
|
+
class DATASET:
|
23
|
+
BASE = "base"
|
24
|
+
REDS_BASE = "reds_base"
|
25
|
+
REALWORLD = "realworld"
|
26
|
+
UHSR = "uhsr"
|
27
|
+
|
28
|
+
# METRIC NAME DEFINITION
|
29
|
+
METRICS = metric_all_names
|
@@ -8,18 +8,18 @@ from .align import Multi_Granularity_Align
|
|
8
8
|
class BasicModel(nn.Module):
|
9
9
|
def __init__(self):
|
10
10
|
super().__init__()
|
11
|
-
|
11
|
+
|
12
12
|
####################################################################################
|
13
13
|
## Tools functions for neural networks
|
14
14
|
def weight_parameters(self):
|
15
|
-
return [param for name, param in self.named_parameters() if
|
15
|
+
return [param for name, param in self.named_parameters() if "weight" in name]
|
16
16
|
|
17
17
|
def bias_parameters(self):
|
18
|
-
return [param for name, param in self.named_parameters() if
|
18
|
+
return [param for name, param in self.named_parameters() if "bias" in name]
|
19
19
|
|
20
20
|
def num_parameters(self):
|
21
21
|
return sum([p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])
|
22
|
-
|
22
|
+
|
23
23
|
def init_weights(self):
|
24
24
|
for layer in self.named_modules():
|
25
25
|
if isinstance(layer, nn.Conv2d):
|
@@ -33,12 +33,21 @@ class BasicModel(nn.Module):
|
|
33
33
|
nn.init.constant_(layer.bias, 0)
|
34
34
|
|
35
35
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
36
|
+
from typing import Literal
|
37
|
+
|
38
|
+
|
39
|
+
def split_and_b_cat(x, spike_dim: Literal[41, 61] = 61):
|
40
|
+
if spike_dim == 61:
|
41
|
+
win_r = 10
|
42
|
+
win_step = 10
|
43
|
+
elif spike_dim == 41:
|
44
|
+
win_r = 6
|
45
|
+
win_step = 7
|
46
|
+
x0 = x[:, 0 : 2 * win_r + 1, :, :].clone()
|
47
|
+
x1 = x[:, win_step : win_step + 2 * win_r + 1, :, :].clone()
|
48
|
+
x2 = x[:, 2 * win_step : 2 * win_step + 2 * win_r + 1, :, :].clone()
|
49
|
+
x3 = x[:, 3 * win_step : 3 * win_step + 2 * win_r + 1, :, :].clone()
|
50
|
+
x4 = x[:, 4 * win_step : 4 * win_step + 2 * win_r + 1, :, :].clone()
|
42
51
|
return torch.cat([x0, x1, x2, x3, x4], dim=0)
|
43
52
|
|
44
53
|
|
@@ -61,39 +70,42 @@ class Encoder(nn.Module):
|
|
61
70
|
x = self.act(conv(x) + x)
|
62
71
|
return x
|
63
72
|
|
73
|
+
|
64
74
|
##########################################################################
|
65
75
|
class BSF(BasicModel):
|
66
|
-
def __init__(self, act=nn.ReLU()):
|
76
|
+
def __init__(self, spike_dim=61, act=nn.ReLU()):
|
67
77
|
super().__init__()
|
78
|
+
self.spike_dim = spike_dim
|
68
79
|
self.offset_groups = 4
|
69
80
|
self.corr_max_disp = 3
|
70
|
-
|
71
|
-
|
72
|
-
|
81
|
+
if spike_dim == 61:
|
82
|
+
self.rep = MODF(in_dim=21,base_dim=64, act=act)
|
83
|
+
elif spike_dim == 41:
|
84
|
+
self.rep = MODF(in_dim=13,base_dim=64, act=act)
|
73
85
|
self.encoder = Encoder(base_dim=64, layers=4, act=act)
|
74
86
|
|
75
87
|
self.align = Multi_Granularity_Align(base_dim=64, groups=self.offset_groups, act=act, sc=3)
|
76
88
|
|
77
89
|
self.recons = nn.Sequential(
|
78
|
-
nn.Conv2d(64*5, 64*3, kernel_size=3, padding=1),
|
90
|
+
nn.Conv2d(64 * 5, 64 * 3, kernel_size=3, padding=1),
|
79
91
|
act,
|
80
|
-
nn.Conv2d(64*3, 64, kernel_size=3, padding=1),
|
92
|
+
nn.Conv2d(64 * 3, 64, kernel_size=3, padding=1),
|
81
93
|
act,
|
82
94
|
nn.Conv2d(64, 1, kernel_size=3, padding=1),
|
83
95
|
)
|
84
96
|
|
85
97
|
def forward(self, input_dict):
|
86
|
-
dsft_dict = input_dict[
|
87
|
-
dsft11 = dsft_dict[
|
88
|
-
dsft12 = dsft_dict[
|
89
|
-
dsft21 = dsft_dict[
|
90
|
-
dsft22 = dsft_dict[
|
98
|
+
dsft_dict = input_dict["dsft_dict"]
|
99
|
+
dsft11 = dsft_dict["dsft11"]
|
100
|
+
dsft12 = dsft_dict["dsft12"]
|
101
|
+
dsft21 = dsft_dict["dsft21"]
|
102
|
+
dsft22 = dsft_dict["dsft22"]
|
91
103
|
|
92
104
|
dsft_b_cat = {
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
105
|
+
"dsft11": split_and_b_cat(dsft11, self.spike_dim),
|
106
|
+
"dsft12": split_and_b_cat(dsft12, self.spike_dim),
|
107
|
+
"dsft21": split_and_b_cat(dsft21, self.spike_dim),
|
108
|
+
"dsft22": split_and_b_cat(dsft22, self.spike_dim),
|
97
109
|
}
|
98
110
|
|
99
111
|
feat_b_cat = self.rep(dsft_b_cat)
|
@@ -2,11 +2,11 @@ import torch
|
|
2
2
|
import torch.nn as nn
|
3
3
|
|
4
4
|
class MODF(nn.Module):
|
5
|
-
def __init__(self, base_dim=64, act=nn.ReLU()):
|
5
|
+
def __init__(self, in_dim = 21, base_dim=64, act=nn.ReLU()):
|
6
6
|
super().__init__()
|
7
7
|
self.base_dim = base_dim
|
8
8
|
|
9
|
-
self.conv1 = self._make_layer(input_dim=
|
9
|
+
self.conv1 = self._make_layer(input_dim=in_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act)
|
10
10
|
self.conv_for_others = nn.ModuleList([
|
11
11
|
self._make_layer(input_dim=self.base_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act) for ii in range(3)
|
12
12
|
])
|
@@ -167,7 +167,7 @@ class FusionMaskV1(nn.Module):
|
|
167
167
|
|
168
168
|
# current best model
|
169
169
|
class SpikeNet(nn.Module):
|
170
|
-
def __init__(self, in_channels, features, out_channels, win_r, win_step):
|
170
|
+
def __init__(self, in_channels = 13, features = 64, out_channels = 1, win_r = 6, win_step = 7):
|
171
171
|
super(SpikeNet, self).__init__()
|
172
172
|
self.extractor = FeatureExtractor(
|
173
173
|
in_channels=in_channels,
|
@@ -272,18 +272,22 @@ class BSN(nn.Module):
|
|
272
272
|
diff = W - H
|
273
273
|
x0 = x0[:, :, (diff // 2):(diff // 2 + H), 0:W]
|
274
274
|
|
275
|
-
return x0
|
275
|
+
return x0
|
276
276
|
|
277
277
|
class DoubleNet(nn.Module):
|
278
278
|
def __init__(self):
|
279
279
|
super().__init__()
|
280
280
|
self.nbsn = BSN(n_channels=41, n_output=1,blind=False)
|
281
|
-
|
281
|
+
self.bsn = BSN(n_channels=41, n_output=1,blind=True)
|
282
282
|
|
283
283
|
def forward(self, x):
|
284
|
-
|
285
|
-
|
286
|
-
|
284
|
+
if self.training:
|
285
|
+
bsn_pred = self.bsn(x)
|
286
|
+
nbsn_pred = self.nbsn(x)
|
287
|
+
return bsn_pred,nbsn_pred
|
288
|
+
else:
|
289
|
+
nbsn_pred = self.nbsn(x)
|
290
|
+
return nbsn_pred
|
287
291
|
|
288
292
|
if __name__ == '__main__':
|
289
293
|
a=DoubleNet().cuda()
|
@@ -292,16 +292,21 @@ class STIRDecorder(nn.Module):#second and third levels
|
|
292
292
|
|
293
293
|
##############################Our Model####################################
|
294
294
|
class STIR(BasicModel):
|
295
|
-
def __init__(self, hidd_chs=8, win_r=6, win_step=7):
|
295
|
+
def __init__(self, spike_dim = 61,hidd_chs=8, win_r=6, win_step=7):
|
296
296
|
super().__init__()
|
297
297
|
|
298
298
|
self.init_chs = [16, 24, 32, 64, 96]
|
299
299
|
self.hidd_chs = hidd_chs
|
300
|
+
self.spike_dim = spike_dim
|
300
301
|
self.attn_num_splits = 1
|
301
302
|
|
302
303
|
self.N_group = 3
|
303
|
-
|
304
|
-
|
304
|
+
if spike_dim == 61:
|
305
|
+
self.resnet = ResidualBlock(in_channles=21, num_channles=11, use_1x1conv=True)
|
306
|
+
dim_tfp = 16 # 5 + num_channels
|
307
|
+
elif spike_dim == 41:
|
308
|
+
self.resnet = ResidualBlock(in_channles=15, num_channles=11, use_1x1conv=True)
|
309
|
+
dim_tfp = 15 # 4 + num_channels
|
305
310
|
self.encoder = ImageEncoder(in_chs=dim_tfp, init_chs=self.init_chs)
|
306
311
|
|
307
312
|
self.transformer = CrossTransformerBlock(dim=self.init_chs[-1], num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
|
@@ -314,14 +319,16 @@ class STIR(BasicModel):
|
|
314
319
|
self.win_r = win_r
|
315
320
|
self.win_step = win_step
|
316
321
|
|
317
|
-
self.resnet = ResidualBlock(in_channles=21, num_channles=11, use_1x1conv=True)
|
318
|
-
|
319
322
|
def forward(self, x):
|
320
323
|
b,_,h,w=x.size()
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
324
|
+
if self.spike_dim == 61:
|
325
|
+
block1 = x[:, 0 : 21, :, :]
|
326
|
+
block2 = x[:, 20 : 41, :, :]
|
327
|
+
block3 = x[:, 40 : 61, :, :]
|
328
|
+
elif self.spike_dim == 41:
|
329
|
+
block1 = x[:, 0 : 15, :, :]
|
330
|
+
block2 = x[:, 13 : 28, :, :]
|
331
|
+
block3 = x[:, 26 : 41, :, :]
|
325
332
|
|
326
333
|
repre1 = TFP(block1, channel_step=2)#C: 5
|
327
334
|
repre2 = TFP(block2, channel_step=2)
|
File without changes
|
File without changes
|
@@ -94,15 +94,15 @@ class Dwt1dModule_Tcn(nn.Module):
|
|
94
94
|
class Dwt1dResnetX_TCN(nn.Module):
|
95
95
|
def __init__(
|
96
96
|
self,
|
97
|
-
wvlname='
|
98
|
-
J=
|
99
|
-
yl_size=
|
100
|
-
yh_size=[
|
101
|
-
num_residual_blocks=
|
97
|
+
wvlname='db8',
|
98
|
+
J=5,
|
99
|
+
yl_size=15,
|
100
|
+
yh_size=[28, 21, 18, 16, 15],
|
101
|
+
num_residual_blocks=3,
|
102
102
|
norm=None,
|
103
103
|
inc=41,
|
104
104
|
ks=3,
|
105
|
-
store_features=
|
105
|
+
store_features=True
|
106
106
|
):
|
107
107
|
super().__init__()
|
108
108
|
|
@@ -4,28 +4,30 @@ import importlib, inspect
|
|
4
4
|
import os
|
5
5
|
import torch
|
6
6
|
from typing import Literal
|
7
|
+
from spikezoo.utils.other_utils import getattr_case_insensitive
|
7
8
|
|
8
9
|
# todo auto detect/register datasets
|
9
10
|
files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
|
10
11
|
dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.endswith("_dataset.py")]
|
11
12
|
|
13
|
+
|
12
14
|
# todo register function
|
13
15
|
def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
|
14
16
|
"""Build the dataset from the given dataset config."""
|
15
17
|
# build new cfg according to split
|
16
|
-
cfg = replace(cfg,split
|
18
|
+
cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
|
17
19
|
# dataset module
|
18
20
|
module_name = cfg.dataset_name + "_dataset"
|
19
21
|
assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
|
20
22
|
module_name = "spikezoo.datasets." + module_name
|
21
23
|
module = importlib.import_module(module_name)
|
22
24
|
# dataset,dataset_config
|
23
|
-
|
24
|
-
|
25
|
+
dataset_name = cfg.dataset_name
|
26
|
+
dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
|
27
|
+
dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
|
25
28
|
dataset = dataset_cls(cfg)
|
26
29
|
return dataset
|
27
30
|
|
28
|
-
|
29
31
|
def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
|
30
32
|
"""Build the default dataset from the given name."""
|
31
33
|
module_name = dataset_name + "_dataset"
|
@@ -33,21 +35,21 @@ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "tes
|
|
33
35
|
module_name = "spikezoo.datasets." + module_name
|
34
36
|
module = importlib.import_module(module_name)
|
35
37
|
# dataset,dataset_config
|
36
|
-
|
37
|
-
dataset_cls: BaseDataset =
|
38
|
-
dataset_cfg: BaseDatasetConfig =
|
38
|
+
dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
|
39
|
+
dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
|
40
|
+
dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")(split=split)
|
39
41
|
dataset = dataset_cls(dataset_cfg)
|
40
42
|
return dataset
|
41
43
|
|
42
44
|
|
43
45
|
# todo to modify according to the basicsr
|
44
|
-
def build_dataloader(dataset: BaseDataset,cfg
|
46
|
+
def build_dataloader(dataset: BaseDataset, cfg=None):
|
45
47
|
# train dataloader
|
46
48
|
if dataset.cfg.split == "train":
|
47
49
|
if cfg is None:
|
48
50
|
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
49
51
|
else:
|
50
|
-
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers,pin_memory=cfg.pin_memory)
|
52
|
+
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
|
51
53
|
# test dataloader
|
52
54
|
elif dataset.cfg.split == "test":
|
53
55
|
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|