rapidtide 2.9.6__py3-none-any.whl → 3.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cloud/gmscalc-HCPYA +1 -1
- cloud/mount-and-run +2 -0
- cloud/rapidtide-HCPYA +3 -3
- rapidtide/Colortables.py +538 -38
- rapidtide/OrthoImageItem.py +1094 -51
- rapidtide/RapidtideDataset.py +1709 -114
- rapidtide/__init__.py +0 -8
- rapidtide/_version.py +4 -4
- rapidtide/calccoherence.py +242 -97
- rapidtide/calcnullsimfunc.py +240 -140
- rapidtide/calcsimfunc.py +314 -129
- rapidtide/correlate.py +1211 -389
- rapidtide/data/examples/src/testLD +56 -0
- rapidtide/data/examples/src/test_findmaxlag.py +2 -2
- rapidtide/data/examples/src/test_mlregressallt.py +32 -17
- rapidtide/data/examples/src/testalign +1 -1
- rapidtide/data/examples/src/testatlasaverage +35 -7
- rapidtide/data/examples/src/testboth +21 -0
- rapidtide/data/examples/src/testcifti +11 -0
- rapidtide/data/examples/src/testdelayvar +13 -0
- rapidtide/data/examples/src/testdlfilt +25 -0
- rapidtide/data/examples/src/testfft +35 -0
- rapidtide/data/examples/src/testfileorfloat +37 -0
- rapidtide/data/examples/src/testfmri +92 -42
- rapidtide/data/examples/src/testfuncs +3 -3
- rapidtide/data/examples/src/testglmfilt +8 -6
- rapidtide/data/examples/src/testhappy +84 -51
- rapidtide/data/examples/src/testinitdelay +19 -0
- rapidtide/data/examples/src/testmodels +33 -0
- rapidtide/data/examples/src/testnewrefine +26 -0
- rapidtide/data/examples/src/testnoiseamp +2 -2
- rapidtide/data/examples/src/testppgproc +17 -0
- rapidtide/data/examples/src/testrefineonly +22 -0
- rapidtide/data/examples/src/testretro +26 -13
- rapidtide/data/examples/src/testretrolagtcs +16 -0
- rapidtide/data/examples/src/testrolloff +11 -0
- rapidtide/data/examples/src/testsimdata +45 -28
- rapidtide/data/models/model_cnn_pytorch/loss.png +0 -0
- rapidtide/data/models/model_cnn_pytorch/loss.txt +1 -0
- rapidtide/data/models/model_cnn_pytorch/model.pth +0 -0
- rapidtide/data/models/model_cnn_pytorch/model_meta.json +68 -0
- rapidtide/data/models/model_cnn_pytorch_fulldata/loss.png +0 -0
- rapidtide/data/models/model_cnn_pytorch_fulldata/loss.txt +1 -0
- rapidtide/data/models/model_cnn_pytorch_fulldata/model.pth +0 -0
- rapidtide/data/models/model_cnn_pytorch_fulldata/model_meta.json +80 -0
- rapidtide/data/models/model_cnnbp_pytorch_fullldata/loss.png +0 -0
- rapidtide/data/models/model_cnnbp_pytorch_fullldata/loss.txt +1 -0
- rapidtide/data/models/model_cnnbp_pytorch_fullldata/model.pth +0 -0
- rapidtide/data/models/model_cnnbp_pytorch_fullldata/model_meta.json +138 -0
- rapidtide/data/models/model_cnnfft_pytorch_fulldata/loss.png +0 -0
- rapidtide/data/models/model_cnnfft_pytorch_fulldata/loss.txt +1 -0
- rapidtide/data/models/model_cnnfft_pytorch_fulldata/model.pth +0 -0
- rapidtide/data/models/model_cnnfft_pytorch_fulldata/model_meta.json +128 -0
- rapidtide/data/models/model_ppgattention_pytorch_w128_fulldata/loss.png +0 -0
- rapidtide/data/models/model_ppgattention_pytorch_w128_fulldata/loss.txt +1 -0
- rapidtide/data/models/model_ppgattention_pytorch_w128_fulldata/model.pth +0 -0
- rapidtide/data/models/model_ppgattention_pytorch_w128_fulldata/model_meta.json +49 -0
- rapidtide/data/models/model_revised_tf2/model.keras +0 -0
- rapidtide/data/models/{model_serdar → model_revised_tf2}/model_meta.json +1 -1
- rapidtide/data/models/model_serdar2_tf2/model.keras +0 -0
- rapidtide/data/models/{model_serdar2 → model_serdar2_tf2}/model_meta.json +1 -1
- rapidtide/data/models/model_serdar_tf2/model.keras +0 -0
- rapidtide/data/models/{model_revised → model_serdar_tf2}/model_meta.json +1 -1
- rapidtide/data/reference/HCP1200v2_MTT_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_binmask_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_csf_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_gray_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_graylaghist.json +7 -0
- rapidtide/data/reference/HCP1200v2_graylaghist.tsv.gz +0 -0
- rapidtide/data/reference/HCP1200v2_laghist.json +7 -0
- rapidtide/data/reference/HCP1200v2_laghist.tsv.gz +0 -0
- rapidtide/data/reference/HCP1200v2_mask_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_maxcorr_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_maxtime_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_maxwidth_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_negmask_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_timepercentile_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_white_2mm.nii.gz +0 -0
- rapidtide/data/reference/HCP1200v2_whitelaghist.json +7 -0
- rapidtide/data/reference/HCP1200v2_whitelaghist.tsv.gz +0 -0
- rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1-seg2.xml +131 -0
- rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1-seg2_regions.txt +60 -0
- rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1-seg2_space-MNI152NLin6Asym_2mm.nii.gz +0 -0
- rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1_space-MNI152NLin2009cAsym_2mm.nii.gz +0 -0
- rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1_space-MNI152NLin2009cAsym_2mm_mask.nii.gz +0 -0
- rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1_space-MNI152NLin6Asym_2mm_mask.nii.gz +0 -0
- rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL2_space-MNI152NLin6Asym_2mm_mask.nii.gz +0 -0
- rapidtide/data/reference/MNI152_T1_1mm_Brain_FAST_seg.nii.gz +0 -0
- rapidtide/data/reference/MNI152_T1_1mm_Brain_Mask.nii.gz +0 -0
- rapidtide/data/reference/MNI152_T1_2mm_Brain_FAST_seg.nii.gz +0 -0
- rapidtide/data/reference/MNI152_T1_2mm_Brain_Mask.nii.gz +0 -0
- rapidtide/decorators.py +91 -0
- rapidtide/dlfilter.py +2553 -414
- rapidtide/dlfiltertorch.py +5201 -0
- rapidtide/externaltools.py +328 -13
- rapidtide/fMRIData_class.py +108 -92
- rapidtide/ffttools.py +168 -0
- rapidtide/filter.py +2704 -1462
- rapidtide/fit.py +2361 -579
- rapidtide/genericmultiproc.py +197 -0
- rapidtide/happy_supportfuncs.py +3255 -548
- rapidtide/helper_classes.py +587 -1116
- rapidtide/io.py +2569 -468
- rapidtide/linfitfiltpass.py +784 -0
- rapidtide/makelaggedtcs.py +267 -97
- rapidtide/maskutil.py +555 -25
- rapidtide/miscmath.py +835 -144
- rapidtide/multiproc.py +217 -44
- rapidtide/patchmatch.py +752 -0
- rapidtide/peakeval.py +32 -32
- rapidtide/ppgproc.py +2205 -0
- rapidtide/qualitycheck.py +353 -40
- rapidtide/refinedelay.py +854 -0
- rapidtide/refineregressor.py +939 -0
- rapidtide/resample.py +725 -204
- rapidtide/scripts/__init__.py +1 -0
- rapidtide/scripts/{adjustoffset → adjustoffset.py} +7 -2
- rapidtide/scripts/{aligntcs → aligntcs.py} +7 -2
- rapidtide/scripts/{applydlfilter → applydlfilter.py} +7 -2
- rapidtide/scripts/applyppgproc.py +28 -0
- rapidtide/scripts/{atlasaverage → atlasaverage.py} +7 -2
- rapidtide/scripts/{atlastool → atlastool.py} +7 -2
- rapidtide/scripts/{calcicc → calcicc.py} +7 -2
- rapidtide/scripts/{calctexticc → calctexticc.py} +7 -2
- rapidtide/scripts/{calcttest → calcttest.py} +7 -2
- rapidtide/scripts/{ccorrica → ccorrica.py} +7 -2
- rapidtide/scripts/delayvar.py +28 -0
- rapidtide/scripts/{diffrois → diffrois.py} +7 -2
- rapidtide/scripts/{endtidalproc → endtidalproc.py} +7 -2
- rapidtide/scripts/{fdica → fdica.py} +7 -2
- rapidtide/scripts/{filtnifti → filtnifti.py} +7 -2
- rapidtide/scripts/{filttc → filttc.py} +7 -2
- rapidtide/scripts/{fingerprint → fingerprint.py} +20 -16
- rapidtide/scripts/{fixtr → fixtr.py} +7 -2
- rapidtide/scripts/{gmscalc → gmscalc.py} +7 -2
- rapidtide/scripts/{happy → happy.py} +7 -2
- rapidtide/scripts/{happy2std → happy2std.py} +7 -2
- rapidtide/scripts/{happywarp → happywarp.py} +8 -4
- rapidtide/scripts/{histnifti → histnifti.py} +7 -2
- rapidtide/scripts/{histtc → histtc.py} +7 -2
- rapidtide/scripts/{glmfilt → linfitfilt.py} +7 -4
- rapidtide/scripts/{localflow → localflow.py} +7 -2
- rapidtide/scripts/{mergequality → mergequality.py} +7 -2
- rapidtide/scripts/{pairproc → pairproc.py} +7 -2
- rapidtide/scripts/{pairwisemergenifti → pairwisemergenifti.py} +7 -2
- rapidtide/scripts/{physiofreq → physiofreq.py} +7 -2
- rapidtide/scripts/{pixelcomp → pixelcomp.py} +7 -2
- rapidtide/scripts/{plethquality → plethquality.py} +7 -2
- rapidtide/scripts/{polyfitim → polyfitim.py} +7 -2
- rapidtide/scripts/{proj2flow → proj2flow.py} +7 -2
- rapidtide/scripts/{rankimage → rankimage.py} +7 -2
- rapidtide/scripts/{rapidtide → rapidtide.py} +7 -2
- rapidtide/scripts/{rapidtide2std → rapidtide2std.py} +7 -2
- rapidtide/scripts/{resamplenifti → resamplenifti.py} +7 -2
- rapidtide/scripts/{resampletc → resampletc.py} +7 -2
- rapidtide/scripts/retrolagtcs.py +28 -0
- rapidtide/scripts/retroregress.py +28 -0
- rapidtide/scripts/{roisummarize → roisummarize.py} +7 -2
- rapidtide/scripts/{runqualitycheck → runqualitycheck.py} +7 -2
- rapidtide/scripts/{showarbcorr → showarbcorr.py} +7 -2
- rapidtide/scripts/{showhist → showhist.py} +7 -2
- rapidtide/scripts/{showstxcorr → showstxcorr.py} +7 -2
- rapidtide/scripts/{showtc → showtc.py} +7 -2
- rapidtide/scripts/{showxcorr_legacy → showxcorr_legacy.py} +8 -8
- rapidtide/scripts/{showxcorrx → showxcorrx.py} +7 -2
- rapidtide/scripts/{showxy → showxy.py} +7 -2
- rapidtide/scripts/{simdata → simdata.py} +7 -2
- rapidtide/scripts/{spatialdecomp → spatialdecomp.py} +7 -2
- rapidtide/scripts/{spatialfit → spatialfit.py} +7 -2
- rapidtide/scripts/{spatialmi → spatialmi.py} +7 -2
- rapidtide/scripts/{spectrogram → spectrogram.py} +7 -2
- rapidtide/scripts/stupidramtricks.py +238 -0
- rapidtide/scripts/{synthASL → synthASL.py} +7 -2
- rapidtide/scripts/{tcfrom2col → tcfrom2col.py} +7 -2
- rapidtide/scripts/{tcfrom3col → tcfrom3col.py} +7 -2
- rapidtide/scripts/{temporaldecomp → temporaldecomp.py} +7 -2
- rapidtide/scripts/{testhrv → testhrv.py} +1 -1
- rapidtide/scripts/{threeD → threeD.py} +7 -2
- rapidtide/scripts/{tidepool → tidepool.py} +7 -2
- rapidtide/scripts/{variabilityizer → variabilityizer.py} +7 -2
- rapidtide/simFuncClasses.py +2113 -0
- rapidtide/simfuncfit.py +312 -108
- rapidtide/stats.py +579 -247
- rapidtide/tests/.coveragerc +27 -6
- rapidtide-2.9.6.data/scripts/fdica → rapidtide/tests/cleanposttest +4 -6
- rapidtide/tests/happycomp +9 -0
- rapidtide/tests/resethappytargets +1 -1
- rapidtide/tests/resetrapidtidetargets +1 -1
- rapidtide/tests/resettargets +1 -1
- rapidtide/tests/runlocaltest +3 -3
- rapidtide/tests/showkernels +1 -1
- rapidtide/tests/test_aliasedcorrelate.py +4 -4
- rapidtide/tests/test_aligntcs.py +1 -1
- rapidtide/tests/test_calcicc.py +1 -1
- rapidtide/tests/test_cleanregressor.py +184 -0
- rapidtide/tests/test_congrid.py +70 -81
- rapidtide/tests/test_correlate.py +1 -1
- rapidtide/tests/test_corrpass.py +4 -4
- rapidtide/tests/test_delayestimation.py +54 -59
- rapidtide/tests/test_dlfiltertorch.py +437 -0
- rapidtide/tests/test_doresample.py +2 -2
- rapidtide/tests/test_externaltools.py +69 -0
- rapidtide/tests/test_fastresampler.py +9 -5
- rapidtide/tests/test_filter.py +96 -57
- rapidtide/tests/test_findmaxlag.py +50 -19
- rapidtide/tests/test_fullrunhappy_v1.py +15 -10
- rapidtide/tests/test_fullrunhappy_v2.py +19 -13
- rapidtide/tests/test_fullrunhappy_v3.py +28 -13
- rapidtide/tests/test_fullrunhappy_v4.py +30 -11
- rapidtide/tests/test_fullrunhappy_v5.py +62 -0
- rapidtide/tests/test_fullrunrapidtide_v1.py +61 -7
- rapidtide/tests/test_fullrunrapidtide_v2.py +26 -14
- rapidtide/tests/test_fullrunrapidtide_v3.py +28 -8
- rapidtide/tests/test_fullrunrapidtide_v4.py +16 -8
- rapidtide/tests/test_fullrunrapidtide_v5.py +15 -6
- rapidtide/tests/test_fullrunrapidtide_v6.py +142 -0
- rapidtide/tests/test_fullrunrapidtide_v7.py +114 -0
- rapidtide/tests/test_fullrunrapidtide_v8.py +66 -0
- rapidtide/tests/test_getparsers.py +158 -0
- rapidtide/tests/test_io.py +59 -18
- rapidtide/tests/{test_glmpass.py → test_linfitfiltpass.py} +10 -10
- rapidtide/tests/test_mi.py +1 -1
- rapidtide/tests/test_miscmath.py +1 -1
- rapidtide/tests/test_motionregress.py +5 -5
- rapidtide/tests/test_nullcorr.py +6 -9
- rapidtide/tests/test_padvec.py +216 -0
- rapidtide/tests/test_parserfuncs.py +101 -0
- rapidtide/tests/test_phaseanalysis.py +1 -1
- rapidtide/tests/test_rapidtideparser.py +59 -53
- rapidtide/tests/test_refinedelay.py +296 -0
- rapidtide/tests/test_runmisc.py +5 -5
- rapidtide/tests/test_sharedmem.py +60 -0
- rapidtide/tests/test_simroundtrip.py +132 -0
- rapidtide/tests/test_simulate.py +1 -1
- rapidtide/tests/test_stcorrelate.py +4 -2
- rapidtide/tests/test_timeshift.py +2 -2
- rapidtide/tests/test_valtoindex.py +1 -1
- rapidtide/tests/test_zRapidtideDataset.py +5 -3
- rapidtide/tests/utils.py +10 -9
- rapidtide/tidepoolTemplate.py +88 -70
- rapidtide/tidepoolTemplate.ui +60 -46
- rapidtide/tidepoolTemplate_alt.py +88 -53
- rapidtide/tidepoolTemplate_alt.ui +62 -52
- rapidtide/tidepoolTemplate_alt_qt6.py +921 -0
- rapidtide/tidepoolTemplate_big.py +1125 -0
- rapidtide/tidepoolTemplate_big.ui +2386 -0
- rapidtide/tidepoolTemplate_big_qt6.py +1129 -0
- rapidtide/tidepoolTemplate_qt6.py +793 -0
- rapidtide/util.py +1389 -148
- rapidtide/voxelData.py +1048 -0
- rapidtide/wiener.py +138 -25
- rapidtide/wiener2.py +114 -8
- rapidtide/workflows/adjustoffset.py +107 -5
- rapidtide/workflows/aligntcs.py +86 -3
- rapidtide/workflows/applydlfilter.py +231 -89
- rapidtide/workflows/applyppgproc.py +540 -0
- rapidtide/workflows/atlasaverage.py +309 -48
- rapidtide/workflows/atlastool.py +130 -9
- rapidtide/workflows/calcSimFuncMap.py +490 -0
- rapidtide/workflows/calctexticc.py +202 -10
- rapidtide/workflows/ccorrica.py +123 -15
- rapidtide/workflows/cleanregressor.py +415 -0
- rapidtide/workflows/delayvar.py +1268 -0
- rapidtide/workflows/diffrois.py +84 -6
- rapidtide/workflows/endtidalproc.py +149 -9
- rapidtide/workflows/fdica.py +197 -17
- rapidtide/workflows/filtnifti.py +71 -4
- rapidtide/workflows/filttc.py +76 -5
- rapidtide/workflows/fitSimFuncMap.py +578 -0
- rapidtide/workflows/fixtr.py +74 -4
- rapidtide/workflows/gmscalc.py +116 -6
- rapidtide/workflows/happy.py +1242 -480
- rapidtide/workflows/happy2std.py +145 -13
- rapidtide/workflows/happy_parser.py +277 -59
- rapidtide/workflows/histnifti.py +120 -4
- rapidtide/workflows/histtc.py +85 -4
- rapidtide/workflows/{glmfilt.py → linfitfilt.py} +128 -14
- rapidtide/workflows/localflow.py +329 -29
- rapidtide/workflows/mergequality.py +80 -4
- rapidtide/workflows/niftidecomp.py +323 -19
- rapidtide/workflows/niftistats.py +178 -8
- rapidtide/workflows/pairproc.py +99 -5
- rapidtide/workflows/pairwisemergenifti.py +86 -3
- rapidtide/workflows/parser_funcs.py +1488 -56
- rapidtide/workflows/physiofreq.py +139 -12
- rapidtide/workflows/pixelcomp.py +211 -9
- rapidtide/workflows/plethquality.py +105 -23
- rapidtide/workflows/polyfitim.py +159 -19
- rapidtide/workflows/proj2flow.py +76 -3
- rapidtide/workflows/rankimage.py +115 -8
- rapidtide/workflows/rapidtide.py +1785 -1858
- rapidtide/workflows/rapidtide2std.py +101 -3
- rapidtide/workflows/rapidtide_parser.py +590 -389
- rapidtide/workflows/refineDelayMap.py +249 -0
- rapidtide/workflows/refineRegressor.py +1215 -0
- rapidtide/workflows/regressfrommaps.py +308 -0
- rapidtide/workflows/resamplenifti.py +86 -4
- rapidtide/workflows/resampletc.py +92 -4
- rapidtide/workflows/retrolagtcs.py +442 -0
- rapidtide/workflows/retroregress.py +1501 -0
- rapidtide/workflows/roisummarize.py +176 -7
- rapidtide/workflows/runqualitycheck.py +72 -7
- rapidtide/workflows/showarbcorr.py +172 -16
- rapidtide/workflows/showhist.py +87 -3
- rapidtide/workflows/showstxcorr.py +161 -4
- rapidtide/workflows/showtc.py +172 -10
- rapidtide/workflows/showxcorrx.py +250 -62
- rapidtide/workflows/showxy.py +186 -16
- rapidtide/workflows/simdata.py +418 -112
- rapidtide/workflows/spatialfit.py +83 -8
- rapidtide/workflows/spatialmi.py +252 -29
- rapidtide/workflows/spectrogram.py +306 -33
- rapidtide/workflows/synthASL.py +157 -6
- rapidtide/workflows/tcfrom2col.py +77 -3
- rapidtide/workflows/tcfrom3col.py +75 -3
- rapidtide/workflows/tidepool.py +3829 -666
- rapidtide/workflows/utils.py +45 -19
- rapidtide/workflows/utils_doc.py +293 -0
- rapidtide/workflows/variabilityizer.py +118 -5
- {rapidtide-2.9.6.dist-info → rapidtide-3.1.3.dist-info}/METADATA +30 -223
- rapidtide-3.1.3.dist-info/RECORD +393 -0
- {rapidtide-2.9.6.dist-info → rapidtide-3.1.3.dist-info}/WHEEL +1 -1
- rapidtide-3.1.3.dist-info/entry_points.txt +65 -0
- rapidtide-3.1.3.dist-info/top_level.txt +2 -0
- rapidtide/calcandfitcorrpairs.py +0 -262
- rapidtide/data/examples/src/testoutputsize +0 -45
- rapidtide/data/models/model_revised/model.h5 +0 -0
- rapidtide/data/models/model_serdar/model.h5 +0 -0
- rapidtide/data/models/model_serdar2/model.h5 +0 -0
- rapidtide/data/reference/ASPECTS_nlin_asym_09c_2mm.nii.gz +0 -0
- rapidtide/data/reference/ASPECTS_nlin_asym_09c_2mm_mask.nii.gz +0 -0
- rapidtide/data/reference/ATTbasedFlowTerritories_split_nlin_asym_09c_2mm.nii.gz +0 -0
- rapidtide/data/reference/ATTbasedFlowTerritories_split_nlin_asym_09c_2mm_mask.nii.gz +0 -0
- rapidtide/data/reference/HCP1200_binmask_2mm_2009c_asym.nii.gz +0 -0
- rapidtide/data/reference/HCP1200_lag_2mm_2009c_asym.nii.gz +0 -0
- rapidtide/data/reference/HCP1200_mask_2mm_2009c_asym.nii.gz +0 -0
- rapidtide/data/reference/HCP1200_negmask_2mm_2009c_asym.nii.gz +0 -0
- rapidtide/data/reference/HCP1200_sigma_2mm_2009c_asym.nii.gz +0 -0
- rapidtide/data/reference/HCP1200_strength_2mm_2009c_asym.nii.gz +0 -0
- rapidtide/glmpass.py +0 -434
- rapidtide/refine_factored.py +0 -641
- rapidtide/scripts/retroglm +0 -23
- rapidtide/workflows/glmfrommaps.py +0 -202
- rapidtide/workflows/retroglm.py +0 -643
- rapidtide-2.9.6.data/scripts/adjustoffset +0 -23
- rapidtide-2.9.6.data/scripts/aligntcs +0 -23
- rapidtide-2.9.6.data/scripts/applydlfilter +0 -23
- rapidtide-2.9.6.data/scripts/atlasaverage +0 -23
- rapidtide-2.9.6.data/scripts/atlastool +0 -23
- rapidtide-2.9.6.data/scripts/calcicc +0 -22
- rapidtide-2.9.6.data/scripts/calctexticc +0 -23
- rapidtide-2.9.6.data/scripts/calcttest +0 -22
- rapidtide-2.9.6.data/scripts/ccorrica +0 -23
- rapidtide-2.9.6.data/scripts/diffrois +0 -23
- rapidtide-2.9.6.data/scripts/endtidalproc +0 -23
- rapidtide-2.9.6.data/scripts/filtnifti +0 -23
- rapidtide-2.9.6.data/scripts/filttc +0 -23
- rapidtide-2.9.6.data/scripts/fingerprint +0 -593
- rapidtide-2.9.6.data/scripts/fixtr +0 -23
- rapidtide-2.9.6.data/scripts/glmfilt +0 -24
- rapidtide-2.9.6.data/scripts/gmscalc +0 -22
- rapidtide-2.9.6.data/scripts/happy +0 -25
- rapidtide-2.9.6.data/scripts/happy2std +0 -23
- rapidtide-2.9.6.data/scripts/happywarp +0 -350
- rapidtide-2.9.6.data/scripts/histnifti +0 -23
- rapidtide-2.9.6.data/scripts/histtc +0 -23
- rapidtide-2.9.6.data/scripts/localflow +0 -23
- rapidtide-2.9.6.data/scripts/mergequality +0 -23
- rapidtide-2.9.6.data/scripts/pairproc +0 -23
- rapidtide-2.9.6.data/scripts/pairwisemergenifti +0 -23
- rapidtide-2.9.6.data/scripts/physiofreq +0 -23
- rapidtide-2.9.6.data/scripts/pixelcomp +0 -23
- rapidtide-2.9.6.data/scripts/plethquality +0 -23
- rapidtide-2.9.6.data/scripts/polyfitim +0 -23
- rapidtide-2.9.6.data/scripts/proj2flow +0 -23
- rapidtide-2.9.6.data/scripts/rankimage +0 -23
- rapidtide-2.9.6.data/scripts/rapidtide +0 -23
- rapidtide-2.9.6.data/scripts/rapidtide2std +0 -23
- rapidtide-2.9.6.data/scripts/resamplenifti +0 -23
- rapidtide-2.9.6.data/scripts/resampletc +0 -23
- rapidtide-2.9.6.data/scripts/retroglm +0 -23
- rapidtide-2.9.6.data/scripts/roisummarize +0 -23
- rapidtide-2.9.6.data/scripts/runqualitycheck +0 -23
- rapidtide-2.9.6.data/scripts/showarbcorr +0 -23
- rapidtide-2.9.6.data/scripts/showhist +0 -23
- rapidtide-2.9.6.data/scripts/showstxcorr +0 -23
- rapidtide-2.9.6.data/scripts/showtc +0 -23
- rapidtide-2.9.6.data/scripts/showxcorr_legacy +0 -536
- rapidtide-2.9.6.data/scripts/showxcorrx +0 -23
- rapidtide-2.9.6.data/scripts/showxy +0 -23
- rapidtide-2.9.6.data/scripts/simdata +0 -23
- rapidtide-2.9.6.data/scripts/spatialdecomp +0 -23
- rapidtide-2.9.6.data/scripts/spatialfit +0 -23
- rapidtide-2.9.6.data/scripts/spatialmi +0 -23
- rapidtide-2.9.6.data/scripts/spectrogram +0 -23
- rapidtide-2.9.6.data/scripts/synthASL +0 -23
- rapidtide-2.9.6.data/scripts/tcfrom2col +0 -23
- rapidtide-2.9.6.data/scripts/tcfrom3col +0 -23
- rapidtide-2.9.6.data/scripts/temporaldecomp +0 -23
- rapidtide-2.9.6.data/scripts/threeD +0 -236
- rapidtide-2.9.6.data/scripts/tidepool +0 -23
- rapidtide-2.9.6.data/scripts/variabilityizer +0 -23
- rapidtide-2.9.6.dist-info/RECORD +0 -359
- rapidtide-2.9.6.dist-info/top_level.txt +0 -86
- {rapidtide-2.9.6.dist-info → rapidtide-3.1.3.dist-info/licenses}/LICENSE +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
#!/usr/bin/env python
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
3
|
#
|
|
4
|
-
# Copyright 2016-
|
|
4
|
+
# Copyright 2016-2025 Blaise Frederick
|
|
5
5
|
#
|
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
7
|
# you may not use this file except in compliance with the License.
|
|
@@ -24,12 +24,13 @@ import numpy as np
|
|
|
24
24
|
|
|
25
25
|
import rapidtide.calcsimfunc as tide_calcsimfunc
|
|
26
26
|
import rapidtide.filter as tide_filt
|
|
27
|
-
import rapidtide.
|
|
28
|
-
import rapidtide.helper_classes as tide_classes
|
|
27
|
+
import rapidtide.linfitfiltpass as tide_linfitfiltpass
|
|
29
28
|
import rapidtide.miscmath as tide_math
|
|
30
29
|
import rapidtide.peakeval as tide_peakeval
|
|
31
30
|
import rapidtide.resample as tide_resample
|
|
31
|
+
import rapidtide.simFuncClasses as tide_simFuncClasses
|
|
32
32
|
import rapidtide.simfuncfit as tide_simfuncfit
|
|
33
|
+
import rapidtide.util as tide_util
|
|
33
34
|
|
|
34
35
|
try:
|
|
35
36
|
import mkl
|
|
@@ -39,34 +40,6 @@ except ImportError:
|
|
|
39
40
|
mklexists = False
|
|
40
41
|
|
|
41
42
|
|
|
42
|
-
def numpy2shared(inarray, thetype):
|
|
43
|
-
thesize = inarray.size
|
|
44
|
-
theshape = inarray.shape
|
|
45
|
-
if thetype == np.float64:
|
|
46
|
-
inarray_shared = mp.RawArray("d", inarray.reshape(thesize))
|
|
47
|
-
else:
|
|
48
|
-
inarray_shared = mp.RawArray("f", inarray.reshape(thesize))
|
|
49
|
-
inarray = np.frombuffer(inarray_shared, dtype=thetype, count=thesize)
|
|
50
|
-
inarray.shape = theshape
|
|
51
|
-
return inarray
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
def allocshared(theshape, thetype):
|
|
55
|
-
thesize = int(1)
|
|
56
|
-
if not isinstance(theshape, (list, tuple)):
|
|
57
|
-
thesize = theshape
|
|
58
|
-
else:
|
|
59
|
-
for element in theshape:
|
|
60
|
-
thesize *= int(element)
|
|
61
|
-
if thetype == np.float64:
|
|
62
|
-
outarray_shared = mp.RawArray("d", thesize)
|
|
63
|
-
else:
|
|
64
|
-
outarray_shared = mp.RawArray("f", thesize)
|
|
65
|
-
outarray = np.frombuffer(outarray_shared, dtype=thetype, count=thesize)
|
|
66
|
-
outarray.shape = theshape
|
|
67
|
-
return outarray, outarray_shared, theshape
|
|
68
|
-
|
|
69
|
-
|
|
70
43
|
def multisine(timepoints, parameterlist):
|
|
71
44
|
output = timepoints * 0.0
|
|
72
45
|
for element in parameterlist:
|
|
@@ -144,7 +117,7 @@ def test_delayestimation(displayplots=False, debug=False):
|
|
|
144
117
|
plt.show()
|
|
145
118
|
|
|
146
119
|
threshval = pedestal / 4.0
|
|
147
|
-
waveforms = numpy2shared(waveforms, np.float64)
|
|
120
|
+
waveforms, waveforms_shm = tide_util.numpy2shared(waveforms, np.float64)
|
|
148
121
|
|
|
149
122
|
referencetc = tide_resample.doresample(
|
|
150
123
|
timepoints, waveforms[refnum, :], oversamptimepoints, method=interptype
|
|
@@ -157,7 +130,7 @@ def test_delayestimation(displayplots=False, debug=False):
|
|
|
157
130
|
# set up theCorrelator
|
|
158
131
|
if debug:
|
|
159
132
|
print("\n\nsetting up theCorrelator")
|
|
160
|
-
theCorrelator =
|
|
133
|
+
theCorrelator = tide_simFuncClasses.Correlator(
|
|
161
134
|
Fs=oversampfreq,
|
|
162
135
|
ncprefilter=theprefilter,
|
|
163
136
|
detrendorder=detrendorder,
|
|
@@ -170,8 +143,8 @@ def test_delayestimation(displayplots=False, debug=False):
|
|
|
170
143
|
dummy, trimmedcorrscale, dummy = theCorrelator.getfunction()
|
|
171
144
|
corroutlen = np.shape(trimmedcorrscale)[0]
|
|
172
145
|
internalvalidcorrshape = (numlocs, corroutlen)
|
|
173
|
-
corrout,
|
|
174
|
-
meanval,
|
|
146
|
+
corrout, corrout_shm = tide_util.allocshared(internalvalidcorrshape, np.float64)
|
|
147
|
+
meanval, meanval_shm = tide_util.allocshared((numlocs), np.float64)
|
|
175
148
|
if debug:
|
|
176
149
|
print("corrout shape:", corrout.shape)
|
|
177
150
|
print("theCorrelator: corroutlen=", corroutlen)
|
|
@@ -179,7 +152,7 @@ def test_delayestimation(displayplots=False, debug=False):
|
|
|
179
152
|
# set up theMutualInformationator
|
|
180
153
|
if debug:
|
|
181
154
|
print("\n\nsetting up theMutualInformationator")
|
|
182
|
-
theMutualInformationator =
|
|
155
|
+
theMutualInformationator = tide_simFuncClasses.MutualInformationator(
|
|
183
156
|
Fs=oversampfreq,
|
|
184
157
|
smoothingtime=smoothingtime,
|
|
185
158
|
ncprefilter=theprefilter,
|
|
@@ -197,7 +170,7 @@ def test_delayestimation(displayplots=False, debug=False):
|
|
|
197
170
|
# set up thefitter
|
|
198
171
|
if debug:
|
|
199
172
|
print("\n\nsetting up thefitter")
|
|
200
|
-
thefitter =
|
|
173
|
+
thefitter = tide_simFuncClasses.SimilarityFunctionFitter(
|
|
201
174
|
lagmod=lagmod,
|
|
202
175
|
lthreshval=0.0,
|
|
203
176
|
uthreshval=1.0,
|
|
@@ -210,21 +183,21 @@ def test_delayestimation(displayplots=False, debug=False):
|
|
|
210
183
|
peakfittype=peakfittype,
|
|
211
184
|
)
|
|
212
185
|
|
|
213
|
-
lagtc,
|
|
214
|
-
fitmask,
|
|
215
|
-
failreason,
|
|
216
|
-
lagtimes,
|
|
217
|
-
lagstrengths,
|
|
218
|
-
lagsigma,
|
|
219
|
-
gaussout,
|
|
220
|
-
windowout,
|
|
221
|
-
rvalue,
|
|
222
|
-
r2value,
|
|
223
|
-
fitcoff,
|
|
224
|
-
fitNorm,
|
|
225
|
-
R2,
|
|
226
|
-
movingsignal,
|
|
227
|
-
filtereddata,
|
|
186
|
+
lagtc, lagtc_shm = tide_util.allocshared(waveforms.shape, np.float64)
|
|
187
|
+
fitmask, fitmask_shm = tide_util.allocshared((numlocs), "uint16")
|
|
188
|
+
failreason, failreason_shm = tide_util.allocshared((numlocs), "uint32")
|
|
189
|
+
lagtimes, lagtimes_shm = tide_util.allocshared((numlocs), np.float64)
|
|
190
|
+
lagstrengths, lagstrengths_shm = tide_util.allocshared((numlocs), np.float64)
|
|
191
|
+
lagsigma, lagsigma_shm = tide_util.allocshared((numlocs), np.float64)
|
|
192
|
+
gaussout, gaussout_shm = tide_util.allocshared(internalvalidcorrshape, np.float64)
|
|
193
|
+
windowout, windowout_shm = tide_util.allocshared(internalvalidcorrshape, np.float64)
|
|
194
|
+
rvalue, rvalue_shm = tide_util.allocshared((numlocs), np.float64)
|
|
195
|
+
r2value, r2value_shm = tide_util.allocshared((numlocs), np.float64)
|
|
196
|
+
fitcoff, fitcoff_shm = tide_util.allocshared((waveforms.shape), np.float64)
|
|
197
|
+
fitNorm, fitNorm_shm = tide_util.allocshared((waveforms.shape), np.float64)
|
|
198
|
+
R2, R2_shm = tide_util.allocshared((numlocs), np.float64)
|
|
199
|
+
movingsignal, movingsignal_shm = tide_util.allocshared(waveforms.shape, np.float64)
|
|
200
|
+
filtereddata, filtereddata_shm = tide_util.allocshared(waveforms.shape, np.float64)
|
|
228
201
|
|
|
229
202
|
for nprocs in [4, 1]:
|
|
230
203
|
# call correlationpass
|
|
@@ -357,12 +330,12 @@ def test_delayestimation(displayplots=False, debug=False):
|
|
|
357
330
|
ax.legend()
|
|
358
331
|
plt.show()
|
|
359
332
|
|
|
360
|
-
filteredwaveforms,
|
|
333
|
+
filteredwaveforms, filteredwaveforms_shm = tide_util.allocshared(waveforms.shape, np.float64)
|
|
361
334
|
for i in range(numlocs):
|
|
362
335
|
filteredwaveforms[i, :] = theprefilter.apply(Fs, waveforms[i, :])
|
|
363
336
|
|
|
364
337
|
for nprocs in [4, 1]:
|
|
365
|
-
|
|
338
|
+
voxelsprocessed_regressionfilt = tide_linfitfiltpass.linfitfiltpass(
|
|
366
339
|
numlocs,
|
|
367
340
|
waveforms[:, :],
|
|
368
341
|
threshval,
|
|
@@ -377,7 +350,7 @@ def test_delayestimation(displayplots=False, debug=False):
|
|
|
377
350
|
nprocs=nprocs,
|
|
378
351
|
alwaysmultiproc=False,
|
|
379
352
|
showprogressbar=False,
|
|
380
|
-
|
|
353
|
+
chunksize=chunksize,
|
|
381
354
|
)
|
|
382
355
|
|
|
383
356
|
if nprocs == 1:
|
|
@@ -387,13 +360,35 @@ def test_delayestimation(displayplots=False, debug=False):
|
|
|
387
360
|
diffsignal = filtereddata
|
|
388
361
|
fig = plt.figure()
|
|
389
362
|
ax = fig.add_subplot(1, 1, 1)
|
|
390
|
-
# ax.plot(
|
|
391
|
-
ax.plot(
|
|
392
|
-
ax.plot(timepoints, movingsignal[refnum, :], label="movingsignal")
|
|
363
|
+
# ax.plot(oversamptimepoints, referencetc, label="referencetc")
|
|
364
|
+
ax.plot(timepoints, waveforms[refnum, :], label="waveform")
|
|
365
|
+
ax.plot(timepoints, fitcoff[refnum] * movingsignal[refnum, :], label="movingsignal")
|
|
366
|
+
# ax.plot(timepoints, filtereddata[refnum, :], label="filtereddata")
|
|
393
367
|
ax.legend()
|
|
394
368
|
plt.show()
|
|
395
369
|
|
|
396
|
-
print(proctype, "
|
|
370
|
+
print(proctype, "linfitfiltpass", np.mean(diffsignal), np.max(np.fabs(diffsignal)))
|
|
371
|
+
|
|
372
|
+
# clean up shared memory
|
|
373
|
+
tide_util.cleanup_shm(waveforms_shm)
|
|
374
|
+
tide_util.cleanup_shm(corrout_shm)
|
|
375
|
+
tide_util.cleanup_shm(meanval_shm)
|
|
376
|
+
tide_util.cleanup_shm(lagtc_shm)
|
|
377
|
+
tide_util.cleanup_shm(fitmask_shm)
|
|
378
|
+
tide_util.cleanup_shm(failreason_shm)
|
|
379
|
+
tide_util.cleanup_shm(lagtimes_shm)
|
|
380
|
+
tide_util.cleanup_shm(lagstrengths_shm)
|
|
381
|
+
tide_util.cleanup_shm(lagsigma_shm)
|
|
382
|
+
tide_util.cleanup_shm(gaussout_shm)
|
|
383
|
+
tide_util.cleanup_shm(windowout_shm)
|
|
384
|
+
tide_util.cleanup_shm(rvalue_shm)
|
|
385
|
+
tide_util.cleanup_shm(r2value_shm)
|
|
386
|
+
tide_util.cleanup_shm(fitcoff_shm)
|
|
387
|
+
tide_util.cleanup_shm(fitNorm_shm)
|
|
388
|
+
tide_util.cleanup_shm(R2_shm)
|
|
389
|
+
tide_util.cleanup_shm(movingsignal_shm)
|
|
390
|
+
tide_util.cleanup_shm(filtereddata_shm)
|
|
391
|
+
tide_util.cleanup_shm(filteredwaveforms_shm)
|
|
397
392
|
|
|
398
393
|
|
|
399
394
|
if __name__ == "__main__":
|
|
@@ -0,0 +1,437 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
#
|
|
4
|
+
# Copyright 2025-2025 Blaise Frederick
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
#
|
|
18
|
+
#
|
|
19
|
+
import os
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
import rapidtide.dlfiltertorch as dlfiltertorch
|
|
25
|
+
from rapidtide.tests.utils import get_test_temp_path, mse
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def create_dummy_data():
|
|
29
|
+
"""Create dummy training data for testing."""
|
|
30
|
+
window_size = 64
|
|
31
|
+
num_samples = 100
|
|
32
|
+
|
|
33
|
+
# Create dummy input and output data
|
|
34
|
+
train_x = np.random.randn(num_samples, window_size, 1).astype(np.float32)
|
|
35
|
+
train_y = np.random.randn(num_samples, window_size, 1).astype(np.float32)
|
|
36
|
+
val_x = np.random.randn(20, window_size, 1).astype(np.float32)
|
|
37
|
+
val_y = np.random.randn(20, window_size, 1).astype(np.float32)
|
|
38
|
+
|
|
39
|
+
return {
|
|
40
|
+
"train_x": train_x,
|
|
41
|
+
"train_y": train_y,
|
|
42
|
+
"val_x": val_x,
|
|
43
|
+
"val_y": val_y,
|
|
44
|
+
"window_size": window_size,
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def cnn_model_creation():
    """Build a CNNModel directly, run one forward pass, and check get_config()."""
    hyperparams = dict(
        num_filters=10,
        kernel_size=5,
        num_layers=3,
        dropout_rate=0.3,
        dilation_rate=1,
        activation="relu",
        inputsize=1,
    )
    model = dlfiltertorch.CNNModel(**hyperparams)

    # input is (batch, inputsize, seq_len); the output must keep that shape
    nbatch, nlen = 4, 64
    nchan = hyperparams["inputsize"]
    testinput = torch.randn(nbatch, nchan, nlen)
    assert model(testinput).shape == (nbatch, nchan, nlen)

    # the reported config must echo the constructor arguments
    theconfig = model.get_config()
    for key in ("num_filters", "kernel_size"):
        assert theconfig[key] == hyperparams[key]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def cnn_dlfilter_initialization(testtemproot):
    """Construct a CNNDLFilter and check the attributes set by the constructor."""
    expected = {"window_size": 64, "num_filters": 10, "kernel_size": 5}
    filter_obj = dlfiltertorch.CNNDLFilter(
        num_layers=3,
        num_epochs=1,
        modelroot=testtemproot,
        **expected,
    )

    for attrname, value in expected.items():
        assert getattr(filter_obj, attrname) == value
    assert filter_obj.nettype == "cnn"
    # the filter should not report itself initialized straight after construction
    assert not filter_obj.initialized
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def cnn_dlfilter_initialize(testtemproot):
    """Build a CNNDLFilter network by hand and save it under its modelpath.

    The full initialize() path is deliberately avoided. According to the
    original author's note, savemodel() uses modelname instead of modelpath by
    default, so the model is saved here with an explicit altname=modelpath.
    """
    filter_obj = dlfiltertorch.CNNDLFilter(
        num_filters=10, kernel_size=5, window_size=64, num_layers=3,
        num_epochs=1, modelroot=testtemproot, namesuffix="test",
    )

    # name the model and build the network without calling initialize()
    filter_obj.getname()
    filter_obj.makenet()
    assert filter_obj.model is not None
    assert os.path.exists(filter_obj.modelpath)

    # move to the working device, then save straight to the full model path
    filter_obj.model.to(filter_obj.device)
    filter_obj.savemodel(altname=filter_obj.modelpath)
    savedfile = os.path.join(filter_obj.modelpath, "model.pth")
    assert os.path.exists(savedfile)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def predict_model(testtemproot, dummy_data):
    """Run predict_model() on numpy validation windows using an untrained CNN."""
    filter_obj = dlfiltertorch.CNNDLFilter(
        num_filters=10,
        kernel_size=5,
        window_size=dummy_data["window_size"],
        num_layers=3,
        num_epochs=1,
        modelroot=testtemproot,
    )

    # build the network directly rather than through initialize()
    filter_obj.getname()
    filter_obj.makenet()
    filter_obj.model.to(filter_obj.device)

    # numpy windows in, a numpy array of the same shape out
    thepreds = filter_obj.predict_model(dummy_data["val_x"])
    expectedshape = dummy_data["val_y"].shape
    assert thepreds.shape == expectedshape
    assert isinstance(thepreds, np.ndarray)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def apply_method(testtemproot):
    """Filter a random 1D signal with an untrained CNN; check the output shape and type."""
    winsize, siglen = 64, 500

    filter_obj = dlfiltertorch.CNNDLFilter(
        num_filters=10, kernel_size=5, window_size=winsize, num_layers=3,
        num_epochs=1, modelroot=testtemproot,
    )

    # set up the network manually instead of calling initialize()
    for setupstep in (filter_obj.getname, filter_obj.makenet):
        setupstep()
    filter_obj.model.to(filter_obj.device)

    thesignal = np.random.randn(siglen).astype(np.float32)
    theresult = filter_obj.apply(thesignal)

    assert theresult.shape == thesignal.shape
    assert isinstance(theresult, np.ndarray)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def apply_method_with_badpts(testtemproot):
    """Filter a random 1D signal with a bad-point mask using an untrained CNN."""
    winsize, siglen = 64, 500

    filter_obj = dlfiltertorch.CNNDLFilter(
        num_filters=10, kernel_size=5, window_size=winsize, num_layers=3,
        num_epochs=1, modelroot=testtemproot, usebadpts=True,
    )

    # set up the network manually instead of calling initialize()
    for setupstep in (filter_obj.getname, filter_obj.makenet):
        setupstep()
    filter_obj.model.to(filter_obj.device)

    thesignal = np.random.randn(siglen).astype(np.float32)
    # float32 mask, same length as the signal; samples 100-119 are flagged with 1.0
    themask = np.zeros_like(thesignal)
    themask[100:120] = 1.0

    theresult = filter_obj.apply(thesignal, badpts=themask)

    assert theresult.shape == thesignal.shape
    assert isinstance(theresult, np.ndarray)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def save_and_load_model(testtemproot):
    """Round-trip a CNNDLFilter model through savemodel() and loadmodel().

    A network is built, its metadata is initialized, and it is saved under
    ``filter_obj.modelpath``. A second filter object pointed at
    ``testtemproot`` then loads it by directory name. The loaded metadata and
    every parameter tensor are compared with the originals.
    """
    filter_obj = dlfiltertorch.CNNDLFilter(
        num_filters=10,
        kernel_size=5,
        window_size=64,
        num_layers=3,
        num_epochs=1,
        modelroot=testtemproot,
        namesuffix="saveloadtest",
    )

    # Build and save the model. savemodel() is given altname=modelpath
    # explicitly, because its default destination reportedly uses the
    # relative modelname.
    filter_obj.getname()
    filter_obj.makenet()
    filter_obj.model.to(filter_obj.device)
    filter_obj.initmetadata()
    filter_obj.savemodel(altname=filter_obj.modelpath)

    # normpath guards against a trailing separator, which would make basename ""
    original_modelname = os.path.basename(os.path.normpath(filter_obj.modelpath))

    # Snapshot the weights on the CPU so the comparison below does not depend
    # on which device each model lives on.
    original_weights = {
        name: param.detach().cpu().clone()
        for name, param in filter_obj.model.named_parameters()
    }

    # Create a new filter object and load the saved model into it
    filter_obj2 = dlfiltertorch.CNNDLFilter(
        num_filters=10,  # presumably overridden by the loaded model's metadata
        kernel_size=5,
        window_size=64,
        num_layers=3,
        num_epochs=1,
        modelroot=testtemproot,
        modelpath=testtemproot,
    )
    filter_obj2.loadmodel(original_modelname)

    # Check that metadata was loaded correctly
    assert filter_obj2.window_size == 64
    assert filter_obj2.infodict["nettype"] == "cnn"

    # Both models must expose exactly the same parameters with matching values
    loaded_weights = {
        name: param.detach().cpu()
        for name, param in filter_obj2.model.named_parameters()
    }
    assert loaded_weights.keys() == original_weights.keys()
    for name, weights in original_weights.items():
        assert torch.allclose(weights, loaded_weights[name]), name
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def filtscale_forward():
    """Run forward filtscale() on a 1D timecourse, with and without log normalization."""
    npts = 64
    thedata = np.random.randn(npts)  # filtscale works on a single 1D timecourse

    # plain scaling: a (npts, 2) array plus a scalar scale factor
    scaled, scalefac = dlfiltertorch.filtscale(thedata, reverse=False, lognormalize=False)
    assert scaled.shape == (npts, 2)
    assert isinstance(scalefac, (float, np.floating))

    # log-normalized scaling must produce the same output layout
    scaled_log, _ = dlfiltertorch.filtscale(thedata, reverse=False, lognormalize=True)
    assert scaled_log.shape == (npts, 2)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def filtscale_reverse():
    """Check that reverse filtscale() approximately undoes the forward transform."""
    original = np.random.randn(64)  # a single 1D timecourse

    forward, thescale = dlfiltertorch.filtscale(original, reverse=False, lognormalize=False)
    roundtrip = dlfiltertorch.filtscale(
        forward, scalefac=thescale, reverse=True, lognormalize=False
    )

    assert roundtrip.shape == original.shape
    # loose bound: only a coarse reconstruction is required
    assert mse(original, roundtrip) < 1.0
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def tobadpts():
    """Check that tobadpts() inserts "_badpts" before the file extension."""
    assert dlfiltertorch.tobadpts("test_file.txt") == "test_file_badpts.txt"
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def targettoinput():
    """Check that targettoinput() replaces the target fragment with the input fragment."""
    converted = dlfiltertorch.targettoinput(
        "test_xyz_file.txt", targetfrag="xyz", inputfrag="abc"
    )
    assert converted == "test_abc_file.txt"
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def model_with_different_activations(testtemproot):
    """Build a CNNModel for relu and tanh, checking output shape and reported config.

    ``testtemproot`` is accepted to match the other checks' signatures but is
    not used here.
    """
    for thename in ("relu", "tanh"):
        model = dlfiltertorch.CNNModel(
            num_filters=10, kernel_size=5, num_layers=3, dropout_rate=0.3,
            dilation_rate=1, activation=thename, inputsize=1,
        )

        # batch of 2, inputsize 1, length 64; the output must keep this shape
        testinput = torch.randn(2, 1, 64)
        assert model(testinput).shape == testinput.shape

        assert model.get_config()["activation"] == thename
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def device_selection():
    """Check that dlfiltertorch.device refers to a CUDA, MPS, or CPU device.

    CUDA/MPS availability cannot be guaranteed in the test environment, so
    only the device *type* is checked. Comparing against
    ``torch.device("cuda")`` directly would reject an indexed device such as
    ``cuda:0``, because torch.device equality also compares the index.
    """
    # torch.device() accepts either a device object or a device string
    thedevice = torch.device(dlfiltertorch.device)
    assert thedevice.type in ("cuda", "mps", "cpu"), str(thedevice)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def infodict_population(testtemproot):
    """Check the infodict contents before and after the network is built."""
    filter_obj = dlfiltertorch.CNNDLFilter(
        num_filters=10, kernel_size=5, window_size=64, num_layers=3,
        dropout_rate=0.3, num_epochs=5, excludethresh=4.0,
        corrthresh_rp=0.5, corrthresh_pp=0.9, modelroot=testtemproot,
    )

    # keys that must already be present after construction
    for key in ("nettype", "num_filters", "kernel_size"):
        assert key in filter_obj.infodict
    assert filter_obj.infodict["nettype"] == "cnn"

    # Build the model. initmetadata() is not called here, per the original
    # author's note about its path handling.
    filter_obj.getname()
    filter_obj.makenet()

    # window_size is expected in infodict once the model has been named and built
    assert "window_size" in filter_obj.infodict
    assert filter_obj.infodict["window_size"] == 64
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def test_dlfilterops(debug=False, local=False):
    """Run every dlfiltertorch check defined in this module, in order.

    Parameters
    ----------
    debug : bool, optional
        If True, print each check's label before running it.
    local : bool, optional
        If True, use ``./tmp`` as the scratch directory; otherwise use
        ``get_test_temp_path()``.
    """
    # set input and output directories
    if local:
        testtemproot = "./tmp"
        # The checks pass this directory as modelroot and write model files
        # under it, so make sure it exists when running locally.
        os.makedirs(testtemproot, exist_ok=True)
    else:
        testtemproot = get_test_temp_path()

    thedummydata = create_dummy_data()

    # (label printed in debug mode, check function, positional args)
    checks = [
        ("cnn_model_creation()", cnn_model_creation, ()),
        ("cnn_dlfilter_initialization(testtemproot)", cnn_dlfilter_initialization, (testtemproot,)),
        ("cnn_dlfilter_initialize(testtemproot)", cnn_dlfilter_initialize, (testtemproot,)),
        ("predict_model(testtemproot, thedummydata)", predict_model, (testtemproot, thedummydata)),
        ("apply_method(testtemproot)", apply_method, (testtemproot,)),
        ("apply_method_with_badpts(testtemproot)", apply_method_with_badpts, (testtemproot,)),
        ("save_and_load_model(testtemproot)", save_and_load_model, (testtemproot,)),
        ("filtscale_forward()", filtscale_forward, ()),
        ("filtscale_reverse()", filtscale_reverse, ()),
        ("tobadpts()", tobadpts, ()),
        ("targettoinput()", targettoinput, ()),
        ("model_with_different_activations(testtemproot)", model_with_different_activations, (testtemproot,)),
        ("device_selection()", device_selection, ()),
        ("infodict_population(testtemproot)", infodict_population, (testtemproot,)),
    ]

    for label, check, args in checks:
        if debug:
            print(label)
        check(*args)
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
if __name__ == "__main__":
    # Run directly as a script: verbose labels, scratch files under ./tmp
    test_dlfilterops(local=True, debug=True)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
#!/usr/bin/env python
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
3
|
#
|
|
4
|
-
# Copyright 2016-
|
|
4
|
+
# Copyright 2016-2025 Blaise Frederick
|
|
5
5
|
#
|
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
7
|
# you may not use this file except in compliance with the License.
|
|
@@ -32,7 +32,7 @@ def test_doresample(debug=False):
|
|
|
32
32
|
shiftdist = 30
|
|
33
33
|
timeaxis = np.arange(0.0, 1.0 * testlen) * tr
|
|
34
34
|
# timecoursein = np.zeros((testlen), dtype='float64')
|
|
35
|
-
timecoursein = np.
|
|
35
|
+
timecoursein = np.zeros_like(timeaxis, np.float64)
|
|
36
36
|
midpoint = int(testlen // 2) + 1
|
|
37
37
|
timecoursein[midpoint - 1] = np.float64(1.0)
|
|
38
38
|
timecoursein[midpoint] = np.float64(1.0)
|