spikezoo 0.2.2__tar.gz → 0.2.3__tar.gz

Files changed (230)
  1. spikezoo-0.2.3/PKG-INFO +263 -0
  2. spikezoo-0.2.3/spikezoo/__init__.py +29 -0
  3. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
  4. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/models/bsf/rep.py +2 -2
  5. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/nets.py +1 -1
  6. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/models/networks.py +1 -1
  7. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssml/model.py +9 -5
  8. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/metrics/losses.py +1 -1
  9. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/models/networks_STIR.py +16 -9
  10. {spikezoo-0.2.2/spikezoo/archs/spikeformer/Metrics → spikezoo-0.2.3/spikezoo/archs/stir/package_core/build/lib/package_core}/__init__.py +0 -0
  11. {spikezoo-0.2.2/spikezoo/archs/spikeformer/Model → spikezoo-0.2.3/spikezoo/archs/stir/package_core/package_core}/__init__.py +0 -0
  12. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/tfi/nets.py +1 -1
  13. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/tfp/nets.py +1 -1
  14. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/dwtnets.py +6 -6
  15. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/datasets/__init__.py +11 -9
  16. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/datasets/base_dataset.py +10 -3
  17. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/datasets/realworld_dataset.py +1 -3
  18. spikezoo-0.2.2/spikezoo/datasets/reds_small_dataset.py → spikezoo-0.2.3/spikezoo/datasets/reds_base_dataset.py +9 -8
  19. spikezoo-0.2.3/spikezoo/datasets/reds_ssir_dataset.py +181 -0
  20. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/datasets/szdata_dataset.py +5 -15
  21. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/datasets/uhsr_dataset.py +4 -3
  22. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/models/__init__.py +8 -6
  23. spikezoo-0.2.3/spikezoo/models/base_model.py +231 -0
  24. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/models/bsf_model.py +11 -3
  25. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/models/spikeclip_model.py +4 -3
  26. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/models/spk2imgnet_model.py +9 -15
  27. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/models/ssir_model.py +4 -6
  28. spikezoo-0.2.3/spikezoo/models/ssml_model.py +60 -0
  29. spikezoo-0.2.3/spikezoo/models/stir_model.py +58 -0
  30. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/models/tfi_model.py +3 -1
  31. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/models/tfp_model.py +4 -2
  32. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/models/wgse_model.py +8 -14
  33. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/pipeline/base_pipeline.py +79 -55
  34. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/pipeline/ensemble_pipeline.py +10 -9
  35. spikezoo-0.2.3/spikezoo/pipeline/train_cfgs.py +89 -0
  36. spikezoo-0.2.3/spikezoo/pipeline/train_pipeline.py +193 -0
  37. spikezoo-0.2.3/spikezoo/utils/optimizer_utils.py +22 -0
  38. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/utils/other_utils.py +31 -6
  39. spikezoo-0.2.3/spikezoo/utils/scheduler_utils.py +25 -0
  40. spikezoo-0.2.3/spikezoo/utils/spike_utils.py +118 -0
  41. spikezoo-0.2.3/spikezoo.egg-info/PKG-INFO +263 -0
  42. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo.egg-info/SOURCES.txt +5 -47
  43. spikezoo-0.2.2/MANIFEST.in +0 -5
  44. spikezoo-0.2.2/PKG-INFO +0 -196
  45. spikezoo-0.2.2/README.md +0 -159
  46. spikezoo-0.2.2/requirements.txt +0 -18
  47. spikezoo-0.2.2/setup.py +0 -23
  48. spikezoo-0.2.2/spikezoo/__init__.py +0 -13
  49. spikezoo-0.2.2/spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  50. spikezoo-0.2.2/spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  51. spikezoo-0.2.2/spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  52. spikezoo-0.2.2/spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  53. spikezoo-0.2.2/spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  54. spikezoo-0.2.2/spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo-0.2.2/spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
  56. spikezoo-0.2.2/spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
  57. spikezoo-0.2.2/spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
  58. spikezoo-0.2.2/spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
  59. spikezoo-0.2.2/spikezoo/archs/spikeformer/EvalResults/readme +0 -1
  60. spikezoo-0.2.2/spikezoo/archs/spikeformer/LICENSE +0 -21
  61. spikezoo-0.2.2/spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
  62. spikezoo-0.2.2/spikezoo/archs/spikeformer/Model/Loss.py +0 -89
  63. spikezoo-0.2.2/spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
  64. spikezoo-0.2.2/spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  65. spikezoo-0.2.2/spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  66. spikezoo-0.2.2/spikezoo/archs/spikeformer/README.md +0 -30
  67. spikezoo-0.2.2/spikezoo/archs/spikeformer/evaluate.py +0 -87
  68. spikezoo-0.2.2/spikezoo/archs/spikeformer/recon_real_data.py +0 -97
  69. spikezoo-0.2.2/spikezoo/archs/spikeformer/requirements.yml +0 -95
  70. spikezoo-0.2.2/spikezoo/archs/spikeformer/train.py +0 -173
  71. spikezoo-0.2.2/spikezoo/archs/spikeformer/utils.py +0 -22
  72. spikezoo-0.2.2/spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  73. spikezoo-0.2.2/spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  74. spikezoo-0.2.2/spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  75. spikezoo-0.2.2/spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  76. spikezoo-0.2.2/spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  77. spikezoo-0.2.2/spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  78. spikezoo-0.2.2/spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  79. spikezoo-0.2.2/spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  80. spikezoo-0.2.2/spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  81. spikezoo-0.2.2/spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  82. spikezoo-0.2.2/spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  83. spikezoo-0.2.2/spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  84. spikezoo-0.2.2/spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  85. spikezoo-0.2.2/spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  86. spikezoo-0.2.2/spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  87. spikezoo-0.2.2/spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  88. spikezoo-0.2.2/spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  89. spikezoo-0.2.2/spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  90. spikezoo-0.2.2/spikezoo/models/base_model.py +0 -175
  91. spikezoo-0.2.2/spikezoo/models/spikeformer_model.py +0 -50
  92. spikezoo-0.2.2/spikezoo/models/ssml_model.py +0 -18
  93. spikezoo-0.2.2/spikezoo/models/stir_model.py +0 -37
  94. spikezoo-0.2.2/spikezoo/pipeline/train_pipeline.py +0 -94
  95. spikezoo-0.2.2/spikezoo/utils/spike_utils.py +0 -86
  96. spikezoo-0.2.2/spikezoo.egg-info/PKG-INFO +0 -196
  97. {spikezoo-0.2.2 → spikezoo-0.2.3}/LICENSE.txt +0 -0
  98. {spikezoo-0.2.2 → spikezoo-0.2.3}/setup.cfg +0 -0
  99. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/__init__.py +0 -0
  100. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/base/nets.py +0 -0
  101. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/README.md +0 -0
  102. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/datasets/datasets.py +0 -0
  103. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/datasets/ds_utils.py +0 -0
  104. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/main.py +0 -0
  105. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/metrics/psnr.py +0 -0
  106. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/metrics/ssim.py +0 -0
  107. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/models/bsf/align.py +0 -0
  108. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/models/bsf/dsft_convert.py +0 -0
  109. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/models/get_model.py +0 -0
  110. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/prepare_data/DSFT.py +0 -0
  111. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +0 -0
  112. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +0 -0
  113. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/prepare_data/crop_train.sh +0 -0
  114. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/prepare_data/crop_val.sh +0 -0
  115. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/prepare_data/io_utils.py +0 -0
  116. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/requirements.txt +0 -0
  117. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/test.py +0 -0
  118. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/bsf/utils.py +0 -0
  119. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spikeclip/nets.py +0 -0
  120. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +0 -0
  121. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/.gitignore +0 -0
  122. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/DCNv2.py +0 -0
  123. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/align_arch.py +0 -0
  124. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/dataset.py +0 -0
  125. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/readme.md +0 -0
  126. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/test_gen_imgseq.py +0 -0
  127. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/train.py +0 -0
  128. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/spk2imgnet/utils.py +0 -0
  129. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/README.md +0 -0
  130. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/configs/SSIR.yml +0 -0
  131. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/configs/yml_parser.py +0 -0
  132. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/datasets/dataset_sreds.py +0 -0
  133. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/datasets/ds_utils.py +0 -0
  134. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/losses.py +0 -0
  135. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/main.py +0 -0
  136. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/metrics/psnr.py +0 -0
  137. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/metrics/ssim.py +0 -0
  138. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/models/Vgg19.py +0 -0
  139. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/models/layers.py +0 -0
  140. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/requirements.txt +0 -0
  141. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/shells/eval_SREDS.sh +0 -0
  142. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/shells/train_SSIR.sh +0 -0
  143. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/test.py +0 -0
  144. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssir/utils.py +0 -0
  145. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssml/cbam.py +0 -0
  146. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssml/res.png +0 -0
  147. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/ssml/test.py +0 -0
  148. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/.git-credentials +0 -0
  149. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/README.md +0 -0
  150. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +0 -0
  151. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/configs/STIR.yml +0 -0
  152. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/configs/utils.py +0 -0
  153. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/configs/yml_parser.py +0 -0
  154. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/datasets/dataset_sreds.py +0 -0
  155. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/datasets/ds_utils.py +0 -0
  156. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/eval_SREDS.sh +0 -0
  157. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/main.py +0 -0
  158. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/metrics/psnr.py +0 -0
  159. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/metrics/ssim.py +0 -0
  160. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/models/Vgg19.py +0 -0
  161. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/models/submodules.py +0 -0
  162. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/models/transformer_new.py +0 -0
  163. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +0 -0
  164. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +0 -0
  165. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +0 -0
  166. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +0 -0
  167. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +0 -0
  168. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +0 -0
  169. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +0 -0
  170. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +0 -0
  171. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +0 -0
  172. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +0 -0
  173. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +0 -0
  174. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +0 -0
  175. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +0 -0
  176. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +0 -0
  177. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  178. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/convertions.py +0 -0
  179. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/disp_netS.py +0 -0
  180. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/flow_utils.py +0 -0
  181. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/generic_train_test.py +0 -0
  182. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/geometry.py +0 -0
  183. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/image_proc.py +0 -0
  184. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/linalg.py +0 -0
  185. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/losses.py +0 -0
  186. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/metrics.py +0 -0
  187. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/model_base.py +0 -0
  188. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/net_basics.py +0 -0
  189. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/resnet.py +0 -0
  190. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/transforms.py +0 -0
  191. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core/utils.py +0 -0
  192. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +0 -0
  193. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +0 -0
  194. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +0 -0
  195. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +0 -0
  196. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/package_core/setup.py +0 -0
  197. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/requirements.txt +0 -0
  198. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/stir/train_STIR.sh +0 -0
  199. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/README.md +0 -0
  200. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/dataset.py +0 -0
  201. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/demo.png +0 -0
  202. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/demo.py +0 -0
  203. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/eval.py +0 -0
  204. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +0 -0
  205. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/submodules.py +0 -0
  206. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/train.py +0 -0
  207. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/transform.py +0 -0
  208. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/utils.py +0 -0
  209. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/archs/wgse/weights/demo.png +0 -0
  210. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  211. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  212. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  213. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  214. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  215. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  216. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  217. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  218. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  219. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  220. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  221. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  222. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/metrics/__init__.py +0 -0
  223. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/pipeline/__init__.py +0 -0
  224. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/utils/__init__.py +0 -0
  225. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/utils/data_utils.py +0 -0
  226. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/utils/img_utils.py +0 -0
  227. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  228. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo.egg-info/dependency_links.txt +0 -0
  229. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo.egg-info/requires.txt +0 -0
  230. {spikezoo-0.2.2 → spikezoo-0.2.3}/spikezoo.egg-info/top_level.txt +0 -0
@@ -0,0 +1,263 @@
1
+ Metadata-Version: 2.2
2
+ Name: spikezoo
3
+ Version: 0.2.3
4
+ Summary: A deep learning toolbox for spike-to-image models.
5
+ Home-page: https://github.com/chenkang455/Spike-Zoo
6
+ Author: Kang Chen
7
+ Author-email: mrchenkang@stu.pku.edu.cn
8
+ Requires-Python: >=3.7
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE.txt
11
+ Requires-Dist: torch
12
+ Requires-Dist: requests
13
+ Requires-Dist: numpy
14
+ Requires-Dist: tqdm
15
+ Requires-Dist: scikit-image
16
+ Requires-Dist: lpips
17
+ Requires-Dist: pyiqa
18
+ Requires-Dist: opencv-python
19
+ Requires-Dist: thop
20
+ Requires-Dist: pytorch-wavelets
21
+ Requires-Dist: pytz
22
+ Requires-Dist: PyWavelets
23
+ Requires-Dist: pandas
24
+ Requires-Dist: pillow
25
+ Requires-Dist: scikit-learn
26
+ Requires-Dist: scipy
27
+ Requires-Dist: spikingjelly
28
+ Requires-Dist: setuptools
29
+ Dynamic: author
30
+ Dynamic: author-email
31
+ Dynamic: description
32
+ Dynamic: description-content-type
33
+ Dynamic: home-page
34
+ Dynamic: requires-dist
35
+ Dynamic: requires-python
36
+ Dynamic: summary
37
+
38
+ <p align="center">
39
+ <br>
40
+ <img src="imgs/spike-zoo.png" width="500"/>
41
+ <br>
42
+ <p>
43
+
44
+ <h5 align="center">
45
+
46
+ [![GitHub repo stars](https://img.shields.io/github/stars/chenkang455/Spike-Zoo?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/chenkang455/Spike-Zoo/stargazers) [![GitHub Issues](https://img.shields.io/github/issues/chenkang455/Spike-Zoo?style=flat&logo=github&logoColor=whitesmoke&label=Issues)](https://github.com/chenkang455/Spike-Zoo/issues) <a href="https://badge.fury.io/py/spikezoo"><img src="https://badge.fury.io/py/spikezoo.svg" alt="PyPI version"></a> [![License](https://img.shields.io/badge/License-MIT-yellow)](https://github.com/chenkang455/Spike-Zoo)
47
+
48
+ <p>
49
+
50
+ <!-- <h2 align="center">
51
+ <a href="">⚡Spike-Zoo: A Toolbox for Spike-to-Image Reconstruction
52
+ </a>
53
+ </h2> -->
54
+
55
+ ## 📖 About
56
+ ⚡Spike-Zoo is the go-to library for state-of-the-art pretrained **spike-to-image** models designed to reconstruct images from spike streams. Whether you're looking for a simple inference solution or aiming to train your own spike-to-image models, ⚡Spike-Zoo is a modular toolbox that supports both, with key features including:
57
+
58
+ - Fast inference with pre-trained models.
59
+ - Training support for custom-designed spike-to-image models.
60
+ - Specialized functions for processing spike data.
61
+
62
+
63
+
64
+ ## 🚩 Updates/Changelog
65
+ * **25-02-02:** Release the `Spike-Zoo v0.2` code, which supports more methods and provides more usage options, such as training your own method from scratch.
66
+ * **24-07-19:** Release the `Spike-Zoo v0.1` code for base evaluation of SOTA methods.
67
+
68
+ ## 🍾 Quick Start
69
+ ### 1. Installation
70
+ For users focused on **utilizing pretrained models for spike-to-image conversion**, we recommend installing SpikeZoo using one of the following methods:
71
+
72
+ * Install the latest stable version `0.2.3` from PyPI:
73
+ ```
74
+ pip install spikezoo
75
+ ```
76
+ * Install the latest development version `0.2.3` from the source code:
77
+ ```
78
+ git clone https://github.com/chenkang455/Spike-Zoo
79
+ cd Spike-Zoo
80
+ python setup.py install
81
+ ```
82
+
83
+ For users interested in **training their own spike-to-image model based on our framework**, we recommend cloning the repository and modifying the related code directly.
84
+ ```
85
+ git clone https://github.com/chenkang455/Spike-Zoo
86
+ cd Spike-Zoo
87
+ python setup.py develop
88
+ ```
89
+
90
+ ### 2. Inference
91
+ Reconstructing images from a spike stream is straightforward with Spike-Zoo. Try the following code to run a single model:
92
+ ``` python
93
+ from spikezoo.pipeline import Pipeline, PipelineConfig
94
+ import spikezoo as sz
95
+ pipeline = Pipeline(
96
+ cfg=PipelineConfig(save_folder="results",version="v023"),
97
+ model_cfg=sz.METHOD.BASE,
98
+ dataset_cfg=sz.DATASET.BASE
99
+ )
100
+ ```
101
+ You can also run multiple models at once by switching to the ensemble pipeline (the `version` parameter corresponds to the different versions we have released in [Releases](https://github.com/chenkang455/Spike-Zoo/releases)):
102
+ ``` python
103
+ import spikezoo as sz
104
+ from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig
105
+ pipeline = EnsemblePipeline(
106
+ cfg=EnsemblePipelineConfig(save_folder="results",version="v023"),
107
+ model_cfg_list=[
108
+ sz.METHOD.BASE,sz.METHOD.TFP,sz.METHOD.TFI,sz.METHOD.SPK2IMGNET,sz.METHOD.WGSE,
109
+ sz.METHOD.SSML,sz.METHOD.BSF,sz.METHOD.STIR,sz.METHOD.SPIKECLIP,sz.METHOD.SSIR],
110
+ dataset_cfg=sz.DATASET.BASE,
111
+ )
112
+ ```
113
+ Once a pipeline is established, the following functions are available for working with these spike-to-image models.
114
+
115
+ * I. Obtain the restoration metrics and save the recovered image from the given spike input:
116
+ ``` python
117
+ # 1. spike-to-image from the given dataset
118
+ pipeline.infer_from_dataset(idx = 0)
119
+
120
+ # 2. spike-to-image from the given .dat file
121
+ pipeline.infer_from_file(file_path = 'data/scissor.dat',width = 400,height=250)
122
+
123
+ # 3. spike-to-image from the given spike
124
+ import spikezoo as sz
125
+ spike = sz.load_vidar_dat("data/scissor.dat",width = 400,height = 250)
126
+ pipeline.infer_from_spk(spike)
127
+ ```
128
+
129
+
130
+ * II. Save all images from the given dataset.
131
+ ``` python
132
+ pipeline.save_imgs_from_dataset()
133
+ ```
134
+
135
+ * III. Calculate the metrics for the specified dataset.
136
+ ``` python
137
+ pipeline.cal_metrics()
138
+ ```
139
+
140
+ * IV. Calculate the model statistics (params, FLOPs, latency) for the established pipeline.
141
+ ``` python
142
+ pipeline.cal_params()
143
+ ```
144
+
145
+ For detailed usage, see [test_single.ipynb](examples/test/test_single.ipynb) and [test_ensemble.ipynb](examples/test/test_ensemble.ipynb).
146
+
147
+ ### 3. Training
148
+ We provide user-friendly code for training our `base` model (modified from `SpikeCLIP`) on the classic `REDS` dataset introduced in `Spk2ImgNet`:
149
+ ``` python
150
+ from spikezoo.pipeline import TrainPipelineConfig, TrainPipeline
151
+ from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig
152
+ from spikezoo.models.base_model import BaseModelConfig
153
+ pipeline = TrainPipeline(
154
+ cfg=TrainPipelineConfig(save_folder="results", epochs = 10),
155
+ dataset_cfg=REDS_BASEConfig(root_dir = "spikezoo/data/REDS_BASE"),
156
+ model_cfg=BaseModelConfig(),
157
+ )
158
+ pipeline.train()
159
+ ```
160
+ Training finishes in `2 minutes` on a single 4090 GPU, achieving `32.8 dB` PSNR and `0.92` SSIM.
161
+
162
+ > 🌟 We encourage users to develop their own models with simple modifications to our framework; a tutorial will be released soon.
163
+
164
+ We retrain all supported methods except `SPIKECLIP` on this REDS dataset (training scripts are in [examples/train_reds_base](examples/train_reds_base) and the evaluation script is in [test_REDS_base.py](examples/test/test_REDS_base.py)); our reported metrics are as follows:
165
+
166
+ | Method | PSNR | SSIM | LPIPS | NIQE | BRISQUE | PIQE | Params (M) | FLOPs (G) | Latency (ms) |
167
+ |----------------------|:-------:|:--------:|:---------:|:---------:|:----------:|:-------:|:------------:|:-----------:|:--------------:|
168
+ | `TFI` | 16.503 | 0.454 | 0.382 | 7.289 | 43.17 | 49.12 | 0.00 | 0.00 | 3.60 |
169
+ | `TFP` | 24.287 | 0.644 | 0.274 | 8.197 | 48.48 | 38.38 | 0.00 | 0.00 | 0.03 |
170
+ | `SPIKECLIP` | 21.873 | 0.578 | 0.333 | 7.802 | 42.08 | 54.01 | 0.19 | 23.69 | 1.27 |
171
+ | `SSIR` | 26.544 | 0.718 | 0.325 | 4.769 | 28.45 | 21.59 | 0.38 | 25.92 | 4.52 |
172
+ | `SSML` | 33.697 | 0.943 | 0.088 | 4.669 | 32.48 | 37.30 | 2.38 | 386.02 | 244.18 |
173
+ | `BASE` | 36.589 | 0.965 | 0.034 | 4.393 | 26.16 | 38.43 | 0.18 | 18.04 | 0.40 |
174
+ | `STIR` | 37.914 | 0.973 | 0.027 | 4.236 | 25.10 | 39.18 | 5.08 | 43.31 | 21.07 |
175
+ | `WGSE` | 39.036 | 0.978 | 0.023 | 4.231 | 25.76 | 44.11 | 3.81 | 415.26 | 73.62 |
176
+ | `SPK2IMGNET` | 39.154 | 0.978 | 0.022 | 4.243 | 25.20 | 43.09 | 3.90 | 1000.50 | 123.38 |
177
+ | `BSF` | 39.576 | 0.979 | 0.019 | 4.139 | 24.93 | 43.03 | 2.47 | 705.23 | 401.50 |
178
+
179
+ ### 4. Model Usage
180
+ We also provide a direct model interface for users interested in using a spike-to-image model as part of their own work:
181
+
182
+ ```python
183
+ import spikezoo as sz
184
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
185
+ # input data
186
+ spike = sz.load_vidar_dat("data/data.dat", width=400, height=250, out_format="tensor")
187
+ spike = spike[None].cuda()
188
+ print(f"Input spike shape: {spike.shape}")
189
+ # net
190
+ net = BaseModel(BaseModelConfig(model_params={"inDim": 41}))
191
+ net.build_network(mode = "debug")
192
+ # process
193
+ recon_img = net(spike)
194
+ print(recon_img.shape,recon_img.max(),recon_img.min())
195
+ ```
196
+ For detailed usage, see [test_model.ipynb](examples/test/test_model.ipynb).
197
+
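A possible follow-up (not from the notebook; it assumes the `recon_img` tensor produced by the snippet above and that `torchvision` is installed) is writing the reconstruction to disk:

```python
# Sketch: save the reconstructed image from the snippet above as a PNG.
# Assumes recon_img has shape (1, 1, H, W); values are clamped to [0, 1] first.
from torchvision.utils import save_image

save_image(recon_img.clamp(0, 1), "recon.png")
```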
198
+ ### 5. Spike Utility
199
+ #### I. Faster spike loading interface
200
+ We provide a faster `load_vidar_dat` function implemented with `cpp` (by [@zeal-ye](https://github.com/zeal-ye)):
201
+ ``` python
202
+ import spikezoo as sz
203
+ spike = sz.load_vidar_dat("data/scissor.dat",width = 400,height = 250,version='cpp')
204
+ ```
205
+ 🚀 Results on [test_load_dat.py](examples/test_load_dat.py) show that the `cpp` version is more than 10 times faster than the `python` version.
206
+
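If you want to verify the speed-up on your own data, here is a minimal timing sketch; `version='python'` is an assumption for the pure-Python fallback (only `version='cpp'` appears above), and the file path is a placeholder:

```python
# Rough benchmark of the two load_vidar_dat backends (illustrative only).
import time
import spikezoo as sz

def time_loader(version: str, repeats: int = 10) -> float:
    # Average wall-clock time per load over a few repeats.
    start = time.perf_counter()
    for _ in range(repeats):
        sz.load_vidar_dat("data/scissor.dat", width=400, height=250, version=version)
    return (time.perf_counter() - start) / repeats

for version in ("python", "cpp"):
    print(f"{version}: {time_loader(version):.4f} s per load")
```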
207
+ #### II. Spike simulation pipeline.
208
+ We provide our overall spike simulation pipeline in [scripts](scripts/); modify the config in `run.sh` and run the following command to start the simulation:
209
+ ``` bash
210
+ bash run.sh
211
+ ```
212
+
213
+ #### III. Spike-related functions.
214
+ For other spike-related functions, see [spike_utils.py](spikezoo/utils/spike_utils.py); a toy illustration of this style of function follows below.
215
+
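To give a sense of what such spike-processing utilities look like, here is a self-contained sketch of a TFP-style reconstruction (averaging a binary spike stream over a temporal window around the key frame) in plain PyTorch. It is an illustration of the idea only, not the `spike_utils.py` API; the window length of 41 mirrors the `model_win_length = 41` default of the TFP/TFI models.

```python
# Illustrative only: TFP-style reconstruction = mean firing rate over a window.
import torch

def tfp_reconstruct(spike: torch.Tensor, win_length: int = 41) -> torch.Tensor:
    """spike: (T, H, W) binary stream; returns an (H, W) image in [0, 1]."""
    key = spike.shape[0] // 2                      # middle (key) timestamp
    half = win_length // 2
    window = spike[key - half : key + half + 1].float()
    return window.mean(dim=0)                      # firing rate ~ brightness

# Toy usage with a random spike stream of 61 frames at 250x400 resolution.
img = tfp_reconstruct(torch.randint(0, 2, (61, 250, 400)))
print(img.shape, img.min().item(), img.max().item())
```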
216
+ ## 📅 TODO
217
+ - [x] Support the overall pipeline for spike simulation.
218
+ - [ ] Provide the tutorials.
219
+ - [ ] Support more training settings.
220
+ - [ ] Support more spike-based image reconstruction methods and datasets.
221
+
222
+ ## 🤗 Supports
223
+ Run the following code to list the supported models, datasets, and metrics:
224
+ ``` python
225
+ import spikezoo as sz
226
+ print(sz.METHODS)
227
+ print(sz.DATASETS)
228
+ print(sz.METRICS)
229
+ ```
230
+ **Supported Models:**
231
+ | Models | Source |
232
+ | ---- | ---- |
233
+ | `tfp`,`tfi` | Spike camera and its coding methods |
234
+ | `spk2imgnet` | Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream |
235
+ | `wgse` | Learning Temporal-Ordered Representation for Spike Streams Based on Discrete Wavelet Transforms |
236
+ | `ssml` | Self-Supervised Mutual Learning for Dynamic Scene Reconstruction of Spiking Camera |
237
+ | `ssir` | Spike Camera Image Reconstruction Using Deep Spiking Neural Networks |
238
+ | `bsf` | Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations |
239
+ | `stir` | Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras |
240
+ | `base`,`spikeclip` | Rethinking High-speed Image Reconstruction Framework with Spike Camera |
241
+
242
+ **Supported Datasets:**
243
+ | Datasets | Source |
244
+ | ---- | ---- |
245
+ | `reds_base` | Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream |
246
+ | `uhsr` | Recognizing Ultra-High-Speed Moving Objects with Bio-Inspired Spike Camera |
247
+ | `realworld` | `recVidarReal2019`,`momVidarReal2021` in [SpikeCV](https://github.com/Zyj061/SpikeCV) |
248
+ | `szdata` | SpikeReveal: Unlocking Temporal Sequences from Real Blurry Inputs with Spike Streams |
249
+
250
+
251
+ ## ✨‍ Acknowledgment
252
+ Our code is built on the open-source projects of [SpikeCV](https://spikecv.github.io/), [IQA-Pytorch](https://github.com/chaofengc/IQA-PyTorch), [BasicSR](https://github.com/XPixelGroup/BasicSR) and [NeRFStudio](https://github.com/nerfstudio-project/nerfstudio). We appreciate the effort of the contributors to these repositories. Thanks to [@ruizhao26](https://github.com/ruizhao26), [@shiyan_chen](https://github.com/hnmizuho) and [@Leozhangjiyuan](https://github.com/Leozhangjiyuan) for their help in building this project.
253
+
254
+ ## 📑 Citation
255
+ If you find our code helpful to your research, please consider citing:
256
+ ```
257
+ @misc{spikezoo,
258
+ title={{Spike-Zoo}: A Toolbox for Spike-to-Image Reconstruction},
259
+ author={Kang Chen and Zhiyuan Ye},
260
+ year={2025},
261
+ howpublished = "[Online]. Available: \url{https://github.com/chenkang455/Spike-Zoo}"
262
+ }
263
+ ```
@@ -0,0 +1,29 @@
1
+ from .utils.spike_utils import *
2
+ from .models import model_list
3
+ from .datasets import dataset_list
4
+ from .metrics import metric_all_names
5
+
6
+ # METHOD NAME DEFINITION
7
+ METHODS = model_list
8
+ class METHOD:
9
+ BASE = "base"
10
+ TFP = "tfp"
11
+ TFI = "tfi"
12
+ SPK2IMGNET = "spk2imgnet"
13
+ WGSE = "wgse"
14
+ SSML = "ssml"
15
+ BSF = "bsf"
16
+ STIR = "stir"
17
+ SSIR = "ssir"
18
+ SPIKECLIP = "spikeclip"
19
+
20
+ # DATASET NAME DEFINITION
21
+ DATASETS = dataset_list
22
+ class DATASET:
23
+ BASE = "base"
24
+ REDS_BASE = "reds_base"
25
+ REALWORLD = "realworld"
26
+ UHSR = "uhsr"
27
+
28
+ # METRIC NAME DEFINITION
29
+ METRICS = metric_all_names
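For orientation, these name constants are what the README's pipeline examples pass as `model_cfg` and `dataset_cfg`. A minimal sketch, assuming the same `Pipeline`/`PipelineConfig` interface shown earlier (the choice of `TFP` and `REDS_BASE` is arbitrary):

```python
# Sketch: plug the name constants above into the Pipeline from the README.
import spikezoo as sz
from spikezoo.pipeline import Pipeline, PipelineConfig

pipeline = Pipeline(
    cfg=PipelineConfig(save_folder="results", version="v023"),
    model_cfg=sz.METHOD.TFP,           # any entry listed in sz.METHODS
    dataset_cfg=sz.DATASET.REDS_BASE,  # any entry listed in sz.DATASETS
)
pipeline.infer_from_dataset(idx=0)
```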
@@ -8,18 +8,18 @@ from .align import Multi_Granularity_Align
8
8
  class BasicModel(nn.Module):
9
9
  def __init__(self):
10
10
  super().__init__()
11
-
11
+
12
12
  ####################################################################################
13
13
  ## Tools functions for neural networks
14
14
  def weight_parameters(self):
15
- return [param for name, param in self.named_parameters() if 'weight' in name]
15
+ return [param for name, param in self.named_parameters() if "weight" in name]
16
16
 
17
17
  def bias_parameters(self):
18
- return [param for name, param in self.named_parameters() if 'bias' in name]
18
+ return [param for name, param in self.named_parameters() if "bias" in name]
19
19
 
20
20
  def num_parameters(self):
21
21
  return sum([p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])
22
-
22
+
23
23
  def init_weights(self):
24
24
  for layer in self.named_modules():
25
25
  if isinstance(layer, nn.Conv2d):
@@ -33,12 +33,21 @@ class BasicModel(nn.Module):
33
33
  nn.init.constant_(layer.bias, 0)
34
34
 
35
35
 
36
- def split_and_b_cat(x):
37
- x0 = x[:, 10-10:10+10+1].clone()
38
- x1 = x[:, 20-10:20+10+1].clone()
39
- x2 = x[:, 30-10:30+10+1].clone()
40
- x3 = x[:, 40-10:40+10+1].clone()
41
- x4 = x[:, 50-10:50+10+1].clone()
36
+ from typing import Literal
37
+
38
+
39
+ def split_and_b_cat(x, spike_dim: Literal[41, 61] = 61):
40
+ if spike_dim == 61:
41
+ win_r = 10
42
+ win_step = 10
43
+ elif spike_dim == 41:
44
+ win_r = 6
45
+ win_step = 7
46
+ x0 = x[:, 0 : 2 * win_r + 1, :, :].clone()
47
+ x1 = x[:, win_step : win_step + 2 * win_r + 1, :, :].clone()
48
+ x2 = x[:, 2 * win_step : 2 * win_step + 2 * win_r + 1, :, :].clone()
49
+ x3 = x[:, 3 * win_step : 3 * win_step + 2 * win_r + 1, :, :].clone()
50
+ x4 = x[:, 4 * win_step : 4 * win_step + 2 * win_r + 1, :, :].clone()
42
51
  return torch.cat([x0, x1, x2, x3, x4], dim=0)
43
52
 
44
53
 
@@ -61,39 +70,42 @@ class Encoder(nn.Module):
61
70
  x = self.act(conv(x) + x)
62
71
  return x
63
72
 
73
+
64
74
  ##########################################################################
65
75
  class BSF(BasicModel):
66
- def __init__(self, act=nn.ReLU()):
76
+ def __init__(self, spike_dim=61, act=nn.ReLU()):
67
77
  super().__init__()
78
+ self.spike_dim = spike_dim
68
79
  self.offset_groups = 4
69
80
  self.corr_max_disp = 3
70
-
71
- self.rep = MODF(base_dim=64, act=act)
72
-
81
+ if spike_dim == 61:
82
+ self.rep = MODF(in_dim=21,base_dim=64, act=act)
83
+ elif spike_dim == 41:
84
+ self.rep = MODF(in_dim=13,base_dim=64, act=act)
73
85
  self.encoder = Encoder(base_dim=64, layers=4, act=act)
74
86
 
75
87
  self.align = Multi_Granularity_Align(base_dim=64, groups=self.offset_groups, act=act, sc=3)
76
88
 
77
89
  self.recons = nn.Sequential(
78
- nn.Conv2d(64*5, 64*3, kernel_size=3, padding=1),
90
+ nn.Conv2d(64 * 5, 64 * 3, kernel_size=3, padding=1),
79
91
  act,
80
- nn.Conv2d(64*3, 64, kernel_size=3, padding=1),
92
+ nn.Conv2d(64 * 3, 64, kernel_size=3, padding=1),
81
93
  act,
82
94
  nn.Conv2d(64, 1, kernel_size=3, padding=1),
83
95
  )
84
96
 
85
97
  def forward(self, input_dict):
86
- dsft_dict = input_dict['dsft_dict']
87
- dsft11 = dsft_dict['dsft11']
88
- dsft12 = dsft_dict['dsft12']
89
- dsft21 = dsft_dict['dsft21']
90
- dsft22 = dsft_dict['dsft22']
98
+ dsft_dict = input_dict["dsft_dict"]
99
+ dsft11 = dsft_dict["dsft11"]
100
+ dsft12 = dsft_dict["dsft12"]
101
+ dsft21 = dsft_dict["dsft21"]
102
+ dsft22 = dsft_dict["dsft22"]
91
103
 
92
104
  dsft_b_cat = {
93
- 'dsft11': split_and_b_cat(dsft11),
94
- 'dsft12': split_and_b_cat(dsft12),
95
- 'dsft21': split_and_b_cat(dsft21),
96
- 'dsft22': split_and_b_cat(dsft22),
105
+ "dsft11": split_and_b_cat(dsft11, self.spike_dim),
106
+ "dsft12": split_and_b_cat(dsft12, self.spike_dim),
107
+ "dsft21": split_and_b_cat(dsft21, self.spike_dim),
108
+ "dsft22": split_and_b_cat(dsft22, self.spike_dim),
97
109
  }
98
110
 
99
111
  feat_b_cat = self.rep(dsft_b_cat)
@@ -2,11 +2,11 @@ import torch
2
2
  import torch.nn as nn
3
3
 
4
4
  class MODF(nn.Module):
5
- def __init__(self, base_dim=64, act=nn.ReLU()):
5
+ def __init__(self, in_dim = 21, base_dim=64, act=nn.ReLU()):
6
6
  super().__init__()
7
7
  self.base_dim = base_dim
8
8
 
9
- self.conv1 = self._make_layer(input_dim=21, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act)
9
+ self.conv1 = self._make_layer(input_dim=in_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act)
10
10
  self.conv_for_others = nn.ModuleList([
11
11
  self._make_layer(input_dim=self.base_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act) for ii in range(3)
12
12
  ])
@@ -167,7 +167,7 @@ class FusionMaskV1(nn.Module):
167
167
 
168
168
  # current best model
169
169
  class SpikeNet(nn.Module):
170
- def __init__(self, in_channels, features, out_channels, win_r, win_step):
170
+ def __init__(self, in_channels = 13, features = 64, out_channels = 1, win_r = 6, win_step = 7):
171
171
  super(SpikeNet, self).__init__()
172
172
  self.extractor = FeatureExtractor(
173
173
  in_channels=in_channels,
@@ -56,6 +56,6 @@ class SSIR(BasicModel):
56
56
  out3 = self.pred3(x7)
57
57
 
58
58
  if self.training:
59
- return [out3]
59
+ return out3
60
60
  else:
61
61
  return out3
@@ -272,18 +272,22 @@ class BSN(nn.Module):
272
272
  diff = W - H
273
273
  x0 = x0[:, :, (diff // 2):(diff // 2 + H), 0:W]
274
274
 
275
- return x0,tfi_label,tfp_label
275
+ return x0
276
276
 
277
277
  class DoubleNet(nn.Module):
278
278
  def __init__(self):
279
279
  super().__init__()
280
280
  self.nbsn = BSN(n_channels=41, n_output=1,blind=False)
281
- # self.bsn = BSN(n_channels=41, n_output=1,blind=True)
281
+ self.bsn = BSN(n_channels=41, n_output=1,blind=True)
282
282
 
283
283
  def forward(self, x):
284
- out1,_,_ = self.nbsn(x)
285
-
286
- return out1
284
+ if self.training:
285
+ bsn_pred = self.bsn(x)
286
+ nbsn_pred = self.nbsn(x)
287
+ return bsn_pred,nbsn_pred
288
+ else:
289
+ nbsn_pred = self.nbsn(x)
290
+ return nbsn_pred
287
291
 
288
292
  if __name__ == '__main__':
289
293
  a=DoubleNet().cuda()
@@ -6,7 +6,7 @@ import torch.nn.functional as F
6
6
 
7
7
  import math
8
8
 
9
- from package_core.losses import *
9
+ from ..package_core.package_core.losses import *
10
10
 
11
11
  def compute_l1_loss(img_list, gt):
12
12
  l1_loss = 0.0
@@ -292,16 +292,21 @@ class STIRDecorder(nn.Module):#second and third levels
292
292
 
293
293
  ##############################Our Model####################################
294
294
  class STIR(BasicModel):
295
- def __init__(self, hidd_chs=8, win_r=6, win_step=7):
295
+ def __init__(self, spike_dim = 61,hidd_chs=8, win_r=6, win_step=7):
296
296
  super().__init__()
297
297
 
298
298
  self.init_chs = [16, 24, 32, 64, 96]
299
299
  self.hidd_chs = hidd_chs
300
+ self.spike_dim = spike_dim
300
301
  self.attn_num_splits = 1
301
302
 
302
303
  self.N_group = 3
303
-
304
- dim_tfp = 16
304
+ if spike_dim == 61:
305
+ self.resnet = ResidualBlock(in_channles=21, num_channles=11, use_1x1conv=True)
306
+ dim_tfp = 16 # 5 + num_channels
307
+ elif spike_dim == 41:
308
+ self.resnet = ResidualBlock(in_channles=15, num_channles=11, use_1x1conv=True)
309
+ dim_tfp = 15 # 4 + num_channels
305
310
  self.encoder = ImageEncoder(in_chs=dim_tfp, init_chs=self.init_chs)
306
311
 
307
312
  self.transformer = CrossTransformerBlock(dim=self.init_chs[-1], num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
@@ -314,14 +319,16 @@ class STIR(BasicModel):
314
319
  self.win_r = win_r
315
320
  self.win_step = win_step
316
321
 
317
- self.resnet = ResidualBlock(in_channles=21, num_channles=11, use_1x1conv=True)
318
-
319
322
  def forward(self, x):
320
323
  b,_,h,w=x.size()
321
-
322
- block1 = x[:, 0 : 21, :, :]
323
- block2 = x[:, 20 : 41, :, :]
324
- block3 = x[:, 40 : 61, :, :]
324
+ if self.spike_dim == 61:
325
+ block1 = x[:, 0 : 21, :, :]
326
+ block2 = x[:, 20 : 41, :, :]
327
+ block3 = x[:, 40 : 61, :, :]
328
+ elif self.spike_dim == 41:
329
+ block1 = x[:, 0 : 15, :, :]
330
+ block2 = x[:, 13 : 28, :, :]
331
+ block3 = x[:, 26 : 41, :, :]
325
332
 
326
333
  repre1 = TFP(block1, channel_step=2)#C: 5
327
334
  repre2 = TFP(block2, channel_step=2)
@@ -6,7 +6,7 @@ import torch
6
6
 
7
7
 
8
8
  class TFIModel(nn.Module):
9
- def __init__(self, model_win_length):
9
+ def __init__(self, model_win_length = 41):
10
10
  super(TFIModel, self).__init__()
11
11
  self.window = model_win_length
12
12
  self.hald_window = model_win_length // 2
@@ -3,7 +3,7 @@ import torch
3
3
 
4
4
 
5
5
  class TFPModel(nn.Module):
6
- def __init__(self, model_win_length):
6
+ def __init__(self, model_win_length = 41):
7
7
  self.window = model_win_length
8
8
  super(TFPModel, self).__init__()
9
9
 
@@ -94,15 +94,15 @@ class Dwt1dModule_Tcn(nn.Module):
94
94
  class Dwt1dResnetX_TCN(nn.Module):
95
95
  def __init__(
96
96
  self,
97
- wvlname='db1',
98
- J=3,
99
- yl_size=14,
100
- yh_size=[26, 18, 14],
101
- num_residual_blocks=2,
97
+ wvlname='db8',
98
+ J=5,
99
+ yl_size=15,
100
+ yh_size=[28, 21, 18, 16, 15],
101
+ num_residual_blocks=3,
102
102
  norm=None,
103
103
  inc=41,
104
104
  ks=3,
105
- store_features=False
105
+ store_features=True
106
106
  ):
107
107
  super().__init__()
108
108
 
@@ -4,28 +4,30 @@ import importlib, inspect
4
4
  import os
5
5
  import torch
6
6
  from typing import Literal
7
+ from spikezoo.utils.other_utils import getattr_case_insensitive
7
8
 
8
9
  # todo auto detect/register datasets
9
10
  files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
10
11
  dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.endswith("_dataset.py")]
11
12
 
13
+
12
14
  # todo register function
13
15
  def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
14
16
  """Build the dataset from the given dataset config."""
15
17
  # build new cfg according to split
16
- cfg = replace(cfg,split = split,spike_length = cfg.spike_length_train if split == "train" else cfg.spike_length_test)
18
+ cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
17
19
  # dataset module
18
20
  module_name = cfg.dataset_name + "_dataset"
19
21
  assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
20
22
  module_name = "spikezoo.datasets." + module_name
21
23
  module = importlib.import_module(module_name)
22
24
  # dataset,dataset_config
23
- classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
24
- dataset_cls: BaseDataset = getattr(module, classes[0])
25
+ dataset_name = cfg.dataset_name
26
+ dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
27
+ dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
25
28
  dataset = dataset_cls(cfg)
26
29
  return dataset
27
30
 
28
-
29
31
  def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
30
32
  """Build the default dataset from the given name."""
31
33
  module_name = dataset_name + "_dataset"
@@ -33,21 +35,21 @@ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "tes
33
35
  module_name = "spikezoo.datasets." + module_name
34
36
  module = importlib.import_module(module_name)
35
37
  # dataset,dataset_config
36
- classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
37
- dataset_cls: BaseDataset = getattr(module, classes[0])
38
- dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])(split=split)
38
+ dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
39
+ dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
40
+ dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")(split=split)
39
41
  dataset = dataset_cls(dataset_cfg)
40
42
  return dataset
41
43
 
42
44
 
43
45
  # todo to modify according to the basicsr
44
- def build_dataloader(dataset: BaseDataset,cfg = None):
46
+ def build_dataloader(dataset: BaseDataset, cfg=None):
45
47
  # train dataloader
46
48
  if dataset.cfg.split == "train":
47
49
  if cfg is None:
48
50
  return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
49
51
  else:
50
- return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers,pin_memory=cfg.pin_memory)
52
+ return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
51
53
  # test dataloader
52
54
  elif dataset.cfg.split == "test":
53
55
  return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)