rapidtide 2.9.6__py3-none-any.whl → 3.1.3__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (405)
  1. cloud/gmscalc-HCPYA +1 -1
  2. cloud/mount-and-run +2 -0
  3. cloud/rapidtide-HCPYA +3 -3
  4. rapidtide/Colortables.py +538 -38
  5. rapidtide/OrthoImageItem.py +1094 -51
  6. rapidtide/RapidtideDataset.py +1709 -114
  7. rapidtide/__init__.py +0 -8
  8. rapidtide/_version.py +4 -4
  9. rapidtide/calccoherence.py +242 -97
  10. rapidtide/calcnullsimfunc.py +240 -140
  11. rapidtide/calcsimfunc.py +314 -129
  12. rapidtide/correlate.py +1211 -389
  13. rapidtide/data/examples/src/testLD +56 -0
  14. rapidtide/data/examples/src/test_findmaxlag.py +2 -2
  15. rapidtide/data/examples/src/test_mlregressallt.py +32 -17
  16. rapidtide/data/examples/src/testalign +1 -1
  17. rapidtide/data/examples/src/testatlasaverage +35 -7
  18. rapidtide/data/examples/src/testboth +21 -0
  19. rapidtide/data/examples/src/testcifti +11 -0
  20. rapidtide/data/examples/src/testdelayvar +13 -0
  21. rapidtide/data/examples/src/testdlfilt +25 -0
  22. rapidtide/data/examples/src/testfft +35 -0
  23. rapidtide/data/examples/src/testfileorfloat +37 -0
  24. rapidtide/data/examples/src/testfmri +92 -42
  25. rapidtide/data/examples/src/testfuncs +3 -3
  26. rapidtide/data/examples/src/testglmfilt +8 -6
  27. rapidtide/data/examples/src/testhappy +84 -51
  28. rapidtide/data/examples/src/testinitdelay +19 -0
  29. rapidtide/data/examples/src/testmodels +33 -0
  30. rapidtide/data/examples/src/testnewrefine +26 -0
  31. rapidtide/data/examples/src/testnoiseamp +2 -2
  32. rapidtide/data/examples/src/testppgproc +17 -0
  33. rapidtide/data/examples/src/testrefineonly +22 -0
  34. rapidtide/data/examples/src/testretro +26 -13
  35. rapidtide/data/examples/src/testretrolagtcs +16 -0
  36. rapidtide/data/examples/src/testrolloff +11 -0
  37. rapidtide/data/examples/src/testsimdata +45 -28
  38. rapidtide/data/models/model_cnn_pytorch/loss.png +0 -0
  39. rapidtide/data/models/model_cnn_pytorch/loss.txt +1 -0
  40. rapidtide/data/models/model_cnn_pytorch/model.pth +0 -0
  41. rapidtide/data/models/model_cnn_pytorch/model_meta.json +68 -0
  42. rapidtide/data/models/model_cnn_pytorch_fulldata/loss.png +0 -0
  43. rapidtide/data/models/model_cnn_pytorch_fulldata/loss.txt +1 -0
  44. rapidtide/data/models/model_cnn_pytorch_fulldata/model.pth +0 -0
  45. rapidtide/data/models/model_cnn_pytorch_fulldata/model_meta.json +80 -0
  46. rapidtide/data/models/model_cnnbp_pytorch_fullldata/loss.png +0 -0
  47. rapidtide/data/models/model_cnnbp_pytorch_fullldata/loss.txt +1 -0
  48. rapidtide/data/models/model_cnnbp_pytorch_fullldata/model.pth +0 -0
  49. rapidtide/data/models/model_cnnbp_pytorch_fullldata/model_meta.json +138 -0
  50. rapidtide/data/models/model_cnnfft_pytorch_fulldata/loss.png +0 -0
  51. rapidtide/data/models/model_cnnfft_pytorch_fulldata/loss.txt +1 -0
  52. rapidtide/data/models/model_cnnfft_pytorch_fulldata/model.pth +0 -0
  53. rapidtide/data/models/model_cnnfft_pytorch_fulldata/model_meta.json +128 -0
  54. rapidtide/data/models/model_ppgattention_pytorch_w128_fulldata/loss.png +0 -0
  55. rapidtide/data/models/model_ppgattention_pytorch_w128_fulldata/loss.txt +1 -0
  56. rapidtide/data/models/model_ppgattention_pytorch_w128_fulldata/model.pth +0 -0
  57. rapidtide/data/models/model_ppgattention_pytorch_w128_fulldata/model_meta.json +49 -0
  58. rapidtide/data/models/model_revised_tf2/model.keras +0 -0
  59. rapidtide/data/models/{model_serdar → model_revised_tf2}/model_meta.json +1 -1
  60. rapidtide/data/models/model_serdar2_tf2/model.keras +0 -0
  61. rapidtide/data/models/{model_serdar2 → model_serdar2_tf2}/model_meta.json +1 -1
  62. rapidtide/data/models/model_serdar_tf2/model.keras +0 -0
  63. rapidtide/data/models/{model_revised → model_serdar_tf2}/model_meta.json +1 -1
  64. rapidtide/data/reference/HCP1200v2_MTT_2mm.nii.gz +0 -0
  65. rapidtide/data/reference/HCP1200v2_binmask_2mm.nii.gz +0 -0
  66. rapidtide/data/reference/HCP1200v2_csf_2mm.nii.gz +0 -0
  67. rapidtide/data/reference/HCP1200v2_gray_2mm.nii.gz +0 -0
  68. rapidtide/data/reference/HCP1200v2_graylaghist.json +7 -0
  69. rapidtide/data/reference/HCP1200v2_graylaghist.tsv.gz +0 -0
  70. rapidtide/data/reference/HCP1200v2_laghist.json +7 -0
  71. rapidtide/data/reference/HCP1200v2_laghist.tsv.gz +0 -0
  72. rapidtide/data/reference/HCP1200v2_mask_2mm.nii.gz +0 -0
  73. rapidtide/data/reference/HCP1200v2_maxcorr_2mm.nii.gz +0 -0
  74. rapidtide/data/reference/HCP1200v2_maxtime_2mm.nii.gz +0 -0
  75. rapidtide/data/reference/HCP1200v2_maxwidth_2mm.nii.gz +0 -0
  76. rapidtide/data/reference/HCP1200v2_negmask_2mm.nii.gz +0 -0
  77. rapidtide/data/reference/HCP1200v2_timepercentile_2mm.nii.gz +0 -0
  78. rapidtide/data/reference/HCP1200v2_white_2mm.nii.gz +0 -0
  79. rapidtide/data/reference/HCP1200v2_whitelaghist.json +7 -0
  80. rapidtide/data/reference/HCP1200v2_whitelaghist.tsv.gz +0 -0
  81. rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1-seg2.xml +131 -0
  82. rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1-seg2_regions.txt +60 -0
  83. rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1-seg2_space-MNI152NLin6Asym_2mm.nii.gz +0 -0
  84. rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1_space-MNI152NLin2009cAsym_2mm.nii.gz +0 -0
  85. rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1_space-MNI152NLin2009cAsym_2mm_mask.nii.gz +0 -0
  86. rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1_space-MNI152NLin6Asym_2mm_mask.nii.gz +0 -0
  87. rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL2_space-MNI152NLin6Asym_2mm_mask.nii.gz +0 -0
  88. rapidtide/data/reference/MNI152_T1_1mm_Brain_FAST_seg.nii.gz +0 -0
  89. rapidtide/data/reference/MNI152_T1_1mm_Brain_Mask.nii.gz +0 -0
  90. rapidtide/data/reference/MNI152_T1_2mm_Brain_FAST_seg.nii.gz +0 -0
  91. rapidtide/data/reference/MNI152_T1_2mm_Brain_Mask.nii.gz +0 -0
  92. rapidtide/decorators.py +91 -0
  93. rapidtide/dlfilter.py +2553 -414
  94. rapidtide/dlfiltertorch.py +5201 -0
  95. rapidtide/externaltools.py +328 -13
  96. rapidtide/fMRIData_class.py +108 -92
  97. rapidtide/ffttools.py +168 -0
  98. rapidtide/filter.py +2704 -1462
  99. rapidtide/fit.py +2361 -579
  100. rapidtide/genericmultiproc.py +197 -0
  101. rapidtide/happy_supportfuncs.py +3255 -548
  102. rapidtide/helper_classes.py +587 -1116
  103. rapidtide/io.py +2569 -468
  104. rapidtide/linfitfiltpass.py +784 -0
  105. rapidtide/makelaggedtcs.py +267 -97
  106. rapidtide/maskutil.py +555 -25
  107. rapidtide/miscmath.py +835 -144
  108. rapidtide/multiproc.py +217 -44
  109. rapidtide/patchmatch.py +752 -0
  110. rapidtide/peakeval.py +32 -32
  111. rapidtide/ppgproc.py +2205 -0
  112. rapidtide/qualitycheck.py +353 -40
  113. rapidtide/refinedelay.py +854 -0
  114. rapidtide/refineregressor.py +939 -0
  115. rapidtide/resample.py +725 -204
  116. rapidtide/scripts/__init__.py +1 -0
  117. rapidtide/scripts/{adjustoffset → adjustoffset.py} +7 -2
  118. rapidtide/scripts/{aligntcs → aligntcs.py} +7 -2
  119. rapidtide/scripts/{applydlfilter → applydlfilter.py} +7 -2
  120. rapidtide/scripts/applyppgproc.py +28 -0
  121. rapidtide/scripts/{atlasaverage → atlasaverage.py} +7 -2
  122. rapidtide/scripts/{atlastool → atlastool.py} +7 -2
  123. rapidtide/scripts/{calcicc → calcicc.py} +7 -2
  124. rapidtide/scripts/{calctexticc → calctexticc.py} +7 -2
  125. rapidtide/scripts/{calcttest → calcttest.py} +7 -2
  126. rapidtide/scripts/{ccorrica → ccorrica.py} +7 -2
  127. rapidtide/scripts/delayvar.py +28 -0
  128. rapidtide/scripts/{diffrois → diffrois.py} +7 -2
  129. rapidtide/scripts/{endtidalproc → endtidalproc.py} +7 -2
  130. rapidtide/scripts/{fdica → fdica.py} +7 -2
  131. rapidtide/scripts/{filtnifti → filtnifti.py} +7 -2
  132. rapidtide/scripts/{filttc → filttc.py} +7 -2
  133. rapidtide/scripts/{fingerprint → fingerprint.py} +20 -16
  134. rapidtide/scripts/{fixtr → fixtr.py} +7 -2
  135. rapidtide/scripts/{gmscalc → gmscalc.py} +7 -2
  136. rapidtide/scripts/{happy → happy.py} +7 -2
  137. rapidtide/scripts/{happy2std → happy2std.py} +7 -2
  138. rapidtide/scripts/{happywarp → happywarp.py} +8 -4
  139. rapidtide/scripts/{histnifti → histnifti.py} +7 -2
  140. rapidtide/scripts/{histtc → histtc.py} +7 -2
  141. rapidtide/scripts/{glmfilt → linfitfilt.py} +7 -4
  142. rapidtide/scripts/{localflow → localflow.py} +7 -2
  143. rapidtide/scripts/{mergequality → mergequality.py} +7 -2
  144. rapidtide/scripts/{pairproc → pairproc.py} +7 -2
  145. rapidtide/scripts/{pairwisemergenifti → pairwisemergenifti.py} +7 -2
  146. rapidtide/scripts/{physiofreq → physiofreq.py} +7 -2
  147. rapidtide/scripts/{pixelcomp → pixelcomp.py} +7 -2
  148. rapidtide/scripts/{plethquality → plethquality.py} +7 -2
  149. rapidtide/scripts/{polyfitim → polyfitim.py} +7 -2
  150. rapidtide/scripts/{proj2flow → proj2flow.py} +7 -2
  151. rapidtide/scripts/{rankimage → rankimage.py} +7 -2
  152. rapidtide/scripts/{rapidtide → rapidtide.py} +7 -2
  153. rapidtide/scripts/{rapidtide2std → rapidtide2std.py} +7 -2
  154. rapidtide/scripts/{resamplenifti → resamplenifti.py} +7 -2
  155. rapidtide/scripts/{resampletc → resampletc.py} +7 -2
  156. rapidtide/scripts/retrolagtcs.py +28 -0
  157. rapidtide/scripts/retroregress.py +28 -0
  158. rapidtide/scripts/{roisummarize → roisummarize.py} +7 -2
  159. rapidtide/scripts/{runqualitycheck → runqualitycheck.py} +7 -2
  160. rapidtide/scripts/{showarbcorr → showarbcorr.py} +7 -2
  161. rapidtide/scripts/{showhist → showhist.py} +7 -2
  162. rapidtide/scripts/{showstxcorr → showstxcorr.py} +7 -2
  163. rapidtide/scripts/{showtc → showtc.py} +7 -2
  164. rapidtide/scripts/{showxcorr_legacy → showxcorr_legacy.py} +8 -8
  165. rapidtide/scripts/{showxcorrx → showxcorrx.py} +7 -2
  166. rapidtide/scripts/{showxy → showxy.py} +7 -2
  167. rapidtide/scripts/{simdata → simdata.py} +7 -2
  168. rapidtide/scripts/{spatialdecomp → spatialdecomp.py} +7 -2
  169. rapidtide/scripts/{spatialfit → spatialfit.py} +7 -2
  170. rapidtide/scripts/{spatialmi → spatialmi.py} +7 -2
  171. rapidtide/scripts/{spectrogram → spectrogram.py} +7 -2
  172. rapidtide/scripts/stupidramtricks.py +238 -0
  173. rapidtide/scripts/{synthASL → synthASL.py} +7 -2
  174. rapidtide/scripts/{tcfrom2col → tcfrom2col.py} +7 -2
  175. rapidtide/scripts/{tcfrom3col → tcfrom3col.py} +7 -2
  176. rapidtide/scripts/{temporaldecomp → temporaldecomp.py} +7 -2
  177. rapidtide/scripts/{testhrv → testhrv.py} +1 -1
  178. rapidtide/scripts/{threeD → threeD.py} +7 -2
  179. rapidtide/scripts/{tidepool → tidepool.py} +7 -2
  180. rapidtide/scripts/{variabilityizer → variabilityizer.py} +7 -2
  181. rapidtide/simFuncClasses.py +2113 -0
  182. rapidtide/simfuncfit.py +312 -108
  183. rapidtide/stats.py +579 -247
  184. rapidtide/tests/.coveragerc +27 -6
  185. rapidtide-2.9.6.data/scripts/fdica → rapidtide/tests/cleanposttest +4 -6
  186. rapidtide/tests/happycomp +9 -0
  187. rapidtide/tests/resethappytargets +1 -1
  188. rapidtide/tests/resetrapidtidetargets +1 -1
  189. rapidtide/tests/resettargets +1 -1
  190. rapidtide/tests/runlocaltest +3 -3
  191. rapidtide/tests/showkernels +1 -1
  192. rapidtide/tests/test_aliasedcorrelate.py +4 -4
  193. rapidtide/tests/test_aligntcs.py +1 -1
  194. rapidtide/tests/test_calcicc.py +1 -1
  195. rapidtide/tests/test_cleanregressor.py +184 -0
  196. rapidtide/tests/test_congrid.py +70 -81
  197. rapidtide/tests/test_correlate.py +1 -1
  198. rapidtide/tests/test_corrpass.py +4 -4
  199. rapidtide/tests/test_delayestimation.py +54 -59
  200. rapidtide/tests/test_dlfiltertorch.py +437 -0
  201. rapidtide/tests/test_doresample.py +2 -2
  202. rapidtide/tests/test_externaltools.py +69 -0
  203. rapidtide/tests/test_fastresampler.py +9 -5
  204. rapidtide/tests/test_filter.py +96 -57
  205. rapidtide/tests/test_findmaxlag.py +50 -19
  206. rapidtide/tests/test_fullrunhappy_v1.py +15 -10
  207. rapidtide/tests/test_fullrunhappy_v2.py +19 -13
  208. rapidtide/tests/test_fullrunhappy_v3.py +28 -13
  209. rapidtide/tests/test_fullrunhappy_v4.py +30 -11
  210. rapidtide/tests/test_fullrunhappy_v5.py +62 -0
  211. rapidtide/tests/test_fullrunrapidtide_v1.py +61 -7
  212. rapidtide/tests/test_fullrunrapidtide_v2.py +26 -14
  213. rapidtide/tests/test_fullrunrapidtide_v3.py +28 -8
  214. rapidtide/tests/test_fullrunrapidtide_v4.py +16 -8
  215. rapidtide/tests/test_fullrunrapidtide_v5.py +15 -6
  216. rapidtide/tests/test_fullrunrapidtide_v6.py +142 -0
  217. rapidtide/tests/test_fullrunrapidtide_v7.py +114 -0
  218. rapidtide/tests/test_fullrunrapidtide_v8.py +66 -0
  219. rapidtide/tests/test_getparsers.py +158 -0
  220. rapidtide/tests/test_io.py +59 -18
  221. rapidtide/tests/{test_glmpass.py → test_linfitfiltpass.py} +10 -10
  222. rapidtide/tests/test_mi.py +1 -1
  223. rapidtide/tests/test_miscmath.py +1 -1
  224. rapidtide/tests/test_motionregress.py +5 -5
  225. rapidtide/tests/test_nullcorr.py +6 -9
  226. rapidtide/tests/test_padvec.py +216 -0
  227. rapidtide/tests/test_parserfuncs.py +101 -0
  228. rapidtide/tests/test_phaseanalysis.py +1 -1
  229. rapidtide/tests/test_rapidtideparser.py +59 -53
  230. rapidtide/tests/test_refinedelay.py +296 -0
  231. rapidtide/tests/test_runmisc.py +5 -5
  232. rapidtide/tests/test_sharedmem.py +60 -0
  233. rapidtide/tests/test_simroundtrip.py +132 -0
  234. rapidtide/tests/test_simulate.py +1 -1
  235. rapidtide/tests/test_stcorrelate.py +4 -2
  236. rapidtide/tests/test_timeshift.py +2 -2
  237. rapidtide/tests/test_valtoindex.py +1 -1
  238. rapidtide/tests/test_zRapidtideDataset.py +5 -3
  239. rapidtide/tests/utils.py +10 -9
  240. rapidtide/tidepoolTemplate.py +88 -70
  241. rapidtide/tidepoolTemplate.ui +60 -46
  242. rapidtide/tidepoolTemplate_alt.py +88 -53
  243. rapidtide/tidepoolTemplate_alt.ui +62 -52
  244. rapidtide/tidepoolTemplate_alt_qt6.py +921 -0
  245. rapidtide/tidepoolTemplate_big.py +1125 -0
  246. rapidtide/tidepoolTemplate_big.ui +2386 -0
  247. rapidtide/tidepoolTemplate_big_qt6.py +1129 -0
  248. rapidtide/tidepoolTemplate_qt6.py +793 -0
  249. rapidtide/util.py +1389 -148
  250. rapidtide/voxelData.py +1048 -0
  251. rapidtide/wiener.py +138 -25
  252. rapidtide/wiener2.py +114 -8
  253. rapidtide/workflows/adjustoffset.py +107 -5
  254. rapidtide/workflows/aligntcs.py +86 -3
  255. rapidtide/workflows/applydlfilter.py +231 -89
  256. rapidtide/workflows/applyppgproc.py +540 -0
  257. rapidtide/workflows/atlasaverage.py +309 -48
  258. rapidtide/workflows/atlastool.py +130 -9
  259. rapidtide/workflows/calcSimFuncMap.py +490 -0
  260. rapidtide/workflows/calctexticc.py +202 -10
  261. rapidtide/workflows/ccorrica.py +123 -15
  262. rapidtide/workflows/cleanregressor.py +415 -0
  263. rapidtide/workflows/delayvar.py +1268 -0
  264. rapidtide/workflows/diffrois.py +84 -6
  265. rapidtide/workflows/endtidalproc.py +149 -9
  266. rapidtide/workflows/fdica.py +197 -17
  267. rapidtide/workflows/filtnifti.py +71 -4
  268. rapidtide/workflows/filttc.py +76 -5
  269. rapidtide/workflows/fitSimFuncMap.py +578 -0
  270. rapidtide/workflows/fixtr.py +74 -4
  271. rapidtide/workflows/gmscalc.py +116 -6
  272. rapidtide/workflows/happy.py +1242 -480
  273. rapidtide/workflows/happy2std.py +145 -13
  274. rapidtide/workflows/happy_parser.py +277 -59
  275. rapidtide/workflows/histnifti.py +120 -4
  276. rapidtide/workflows/histtc.py +85 -4
  277. rapidtide/workflows/{glmfilt.py → linfitfilt.py} +128 -14
  278. rapidtide/workflows/localflow.py +329 -29
  279. rapidtide/workflows/mergequality.py +80 -4
  280. rapidtide/workflows/niftidecomp.py +323 -19
  281. rapidtide/workflows/niftistats.py +178 -8
  282. rapidtide/workflows/pairproc.py +99 -5
  283. rapidtide/workflows/pairwisemergenifti.py +86 -3
  284. rapidtide/workflows/parser_funcs.py +1488 -56
  285. rapidtide/workflows/physiofreq.py +139 -12
  286. rapidtide/workflows/pixelcomp.py +211 -9
  287. rapidtide/workflows/plethquality.py +105 -23
  288. rapidtide/workflows/polyfitim.py +159 -19
  289. rapidtide/workflows/proj2flow.py +76 -3
  290. rapidtide/workflows/rankimage.py +115 -8
  291. rapidtide/workflows/rapidtide.py +1785 -1858
  292. rapidtide/workflows/rapidtide2std.py +101 -3
  293. rapidtide/workflows/rapidtide_parser.py +590 -389
  294. rapidtide/workflows/refineDelayMap.py +249 -0
  295. rapidtide/workflows/refineRegressor.py +1215 -0
  296. rapidtide/workflows/regressfrommaps.py +308 -0
  297. rapidtide/workflows/resamplenifti.py +86 -4
  298. rapidtide/workflows/resampletc.py +92 -4
  299. rapidtide/workflows/retrolagtcs.py +442 -0
  300. rapidtide/workflows/retroregress.py +1501 -0
  301. rapidtide/workflows/roisummarize.py +176 -7
  302. rapidtide/workflows/runqualitycheck.py +72 -7
  303. rapidtide/workflows/showarbcorr.py +172 -16
  304. rapidtide/workflows/showhist.py +87 -3
  305. rapidtide/workflows/showstxcorr.py +161 -4
  306. rapidtide/workflows/showtc.py +172 -10
  307. rapidtide/workflows/showxcorrx.py +250 -62
  308. rapidtide/workflows/showxy.py +186 -16
  309. rapidtide/workflows/simdata.py +418 -112
  310. rapidtide/workflows/spatialfit.py +83 -8
  311. rapidtide/workflows/spatialmi.py +252 -29
  312. rapidtide/workflows/spectrogram.py +306 -33
  313. rapidtide/workflows/synthASL.py +157 -6
  314. rapidtide/workflows/tcfrom2col.py +77 -3
  315. rapidtide/workflows/tcfrom3col.py +75 -3
  316. rapidtide/workflows/tidepool.py +3829 -666
  317. rapidtide/workflows/utils.py +45 -19
  318. rapidtide/workflows/utils_doc.py +293 -0
  319. rapidtide/workflows/variabilityizer.py +118 -5
  320. {rapidtide-2.9.6.dist-info → rapidtide-3.1.3.dist-info}/METADATA +30 -223
  321. rapidtide-3.1.3.dist-info/RECORD +393 -0
  322. {rapidtide-2.9.6.dist-info → rapidtide-3.1.3.dist-info}/WHEEL +1 -1
  323. rapidtide-3.1.3.dist-info/entry_points.txt +65 -0
  324. rapidtide-3.1.3.dist-info/top_level.txt +2 -0
  325. rapidtide/calcandfitcorrpairs.py +0 -262
  326. rapidtide/data/examples/src/testoutputsize +0 -45
  327. rapidtide/data/models/model_revised/model.h5 +0 -0
  328. rapidtide/data/models/model_serdar/model.h5 +0 -0
  329. rapidtide/data/models/model_serdar2/model.h5 +0 -0
  330. rapidtide/data/reference/ASPECTS_nlin_asym_09c_2mm.nii.gz +0 -0
  331. rapidtide/data/reference/ASPECTS_nlin_asym_09c_2mm_mask.nii.gz +0 -0
  332. rapidtide/data/reference/ATTbasedFlowTerritories_split_nlin_asym_09c_2mm.nii.gz +0 -0
  333. rapidtide/data/reference/ATTbasedFlowTerritories_split_nlin_asym_09c_2mm_mask.nii.gz +0 -0
  334. rapidtide/data/reference/HCP1200_binmask_2mm_2009c_asym.nii.gz +0 -0
  335. rapidtide/data/reference/HCP1200_lag_2mm_2009c_asym.nii.gz +0 -0
  336. rapidtide/data/reference/HCP1200_mask_2mm_2009c_asym.nii.gz +0 -0
  337. rapidtide/data/reference/HCP1200_negmask_2mm_2009c_asym.nii.gz +0 -0
  338. rapidtide/data/reference/HCP1200_sigma_2mm_2009c_asym.nii.gz +0 -0
  339. rapidtide/data/reference/HCP1200_strength_2mm_2009c_asym.nii.gz +0 -0
  340. rapidtide/glmpass.py +0 -434
  341. rapidtide/refine_factored.py +0 -641
  342. rapidtide/scripts/retroglm +0 -23
  343. rapidtide/workflows/glmfrommaps.py +0 -202
  344. rapidtide/workflows/retroglm.py +0 -643
  345. rapidtide-2.9.6.data/scripts/adjustoffset +0 -23
  346. rapidtide-2.9.6.data/scripts/aligntcs +0 -23
  347. rapidtide-2.9.6.data/scripts/applydlfilter +0 -23
  348. rapidtide-2.9.6.data/scripts/atlasaverage +0 -23
  349. rapidtide-2.9.6.data/scripts/atlastool +0 -23
  350. rapidtide-2.9.6.data/scripts/calcicc +0 -22
  351. rapidtide-2.9.6.data/scripts/calctexticc +0 -23
  352. rapidtide-2.9.6.data/scripts/calcttest +0 -22
  353. rapidtide-2.9.6.data/scripts/ccorrica +0 -23
  354. rapidtide-2.9.6.data/scripts/diffrois +0 -23
  355. rapidtide-2.9.6.data/scripts/endtidalproc +0 -23
  356. rapidtide-2.9.6.data/scripts/filtnifti +0 -23
  357. rapidtide-2.9.6.data/scripts/filttc +0 -23
  358. rapidtide-2.9.6.data/scripts/fingerprint +0 -593
  359. rapidtide-2.9.6.data/scripts/fixtr +0 -23
  360. rapidtide-2.9.6.data/scripts/glmfilt +0 -24
  361. rapidtide-2.9.6.data/scripts/gmscalc +0 -22
  362. rapidtide-2.9.6.data/scripts/happy +0 -25
  363. rapidtide-2.9.6.data/scripts/happy2std +0 -23
  364. rapidtide-2.9.6.data/scripts/happywarp +0 -350
  365. rapidtide-2.9.6.data/scripts/histnifti +0 -23
  366. rapidtide-2.9.6.data/scripts/histtc +0 -23
  367. rapidtide-2.9.6.data/scripts/localflow +0 -23
  368. rapidtide-2.9.6.data/scripts/mergequality +0 -23
  369. rapidtide-2.9.6.data/scripts/pairproc +0 -23
  370. rapidtide-2.9.6.data/scripts/pairwisemergenifti +0 -23
  371. rapidtide-2.9.6.data/scripts/physiofreq +0 -23
  372. rapidtide-2.9.6.data/scripts/pixelcomp +0 -23
  373. rapidtide-2.9.6.data/scripts/plethquality +0 -23
  374. rapidtide-2.9.6.data/scripts/polyfitim +0 -23
  375. rapidtide-2.9.6.data/scripts/proj2flow +0 -23
  376. rapidtide-2.9.6.data/scripts/rankimage +0 -23
  377. rapidtide-2.9.6.data/scripts/rapidtide +0 -23
  378. rapidtide-2.9.6.data/scripts/rapidtide2std +0 -23
  379. rapidtide-2.9.6.data/scripts/resamplenifti +0 -23
  380. rapidtide-2.9.6.data/scripts/resampletc +0 -23
  381. rapidtide-2.9.6.data/scripts/retroglm +0 -23
  382. rapidtide-2.9.6.data/scripts/roisummarize +0 -23
  383. rapidtide-2.9.6.data/scripts/runqualitycheck +0 -23
  384. rapidtide-2.9.6.data/scripts/showarbcorr +0 -23
  385. rapidtide-2.9.6.data/scripts/showhist +0 -23
  386. rapidtide-2.9.6.data/scripts/showstxcorr +0 -23
  387. rapidtide-2.9.6.data/scripts/showtc +0 -23
  388. rapidtide-2.9.6.data/scripts/showxcorr_legacy +0 -536
  389. rapidtide-2.9.6.data/scripts/showxcorrx +0 -23
  390. rapidtide-2.9.6.data/scripts/showxy +0 -23
  391. rapidtide-2.9.6.data/scripts/simdata +0 -23
  392. rapidtide-2.9.6.data/scripts/spatialdecomp +0 -23
  393. rapidtide-2.9.6.data/scripts/spatialfit +0 -23
  394. rapidtide-2.9.6.data/scripts/spatialmi +0 -23
  395. rapidtide-2.9.6.data/scripts/spectrogram +0 -23
  396. rapidtide-2.9.6.data/scripts/synthASL +0 -23
  397. rapidtide-2.9.6.data/scripts/tcfrom2col +0 -23
  398. rapidtide-2.9.6.data/scripts/tcfrom3col +0 -23
  399. rapidtide-2.9.6.data/scripts/temporaldecomp +0 -23
  400. rapidtide-2.9.6.data/scripts/threeD +0 -236
  401. rapidtide-2.9.6.data/scripts/tidepool +0 -23
  402. rapidtide-2.9.6.data/scripts/variabilityizer +0 -23
  403. rapidtide-2.9.6.dist-info/RECORD +0 -359
  404. rapidtide-2.9.6.dist-info/top_level.txt +0 -86
  405. {rapidtide-2.9.6.dist-info → rapidtide-3.1.3.dist-info/licenses}/LICENSE +0 -0
rapidtide/correlate.py CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 #
-# Copyright 2016-2024 Blaise Frederick
+# Copyright 2016-2025 Blaise Frederick
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,9 +19,13 @@
 """Functions for calculating correlations and similar metrics between arrays."""
 import logging
 import warnings
+from typing import Any, Callable, Optional, Tuple, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
+from numpy.typing import NDArray
+
+from rapidtide.ffttools import optfftlen
 
 with warnings.catch_warnings():
     warnings.simplefilter("ignore")
@@ -38,12 +42,12 @@ from numpy.fft import irfftn, rfftn
 from scipy import fftpack, signal
 from sklearn.metrics import mutual_info_score
 
-import rapidtide.correlate as tide_corr
 import rapidtide.fit as tide_fit
 import rapidtide.miscmath as tide_math
 import rapidtide.resample as tide_resample
 import rapidtide.stats as tide_stats
 import rapidtide.util as tide_util
+from rapidtide.decorators import conditionaljit
 
 if pyfftwpresent:
     fftpack = pyfftw.interfaces.scipy_fftpack
@@ -56,64 +60,80 @@ MAXLINES = 10000000
 donotbeaggressive = True
 
 
-# ----------------------------------------- Conditional imports ---------------------------------------
-try:
-    from numba import jit
-except ImportError:
-    donotusenumba = True
-else:
-    donotusenumba = False
-
-
-def conditionaljit():
-    """Wrap functions in jit if numba is enabled."""
-
-    def resdec(f):
-        if donotusenumba:
-            return f
-        return jit(f, nopython=True)
-
-    return resdec
-
-
-def disablenumba():
-    """Set a global variable to disable numba."""
-    global donotusenumba
-    donotusenumba = True
-
-
 # --------------------------- Correlation functions -------------------------------------------------
 def check_autocorrelation(
-    corrscale,
-    thexcorr,
-    delta=0.1,
-    acampthresh=0.1,
-    aclagthresh=10.0,
-    displayplots=False,
-    detrendorder=1,
-):
-    """Check for autocorrelation in an array.
+    corrscale: NDArray,
+    thexcorr: NDArray,
+    delta: float = 0.05,
+    acampthresh: float = 0.1,
+    aclagthresh: float = 10.0,
+    displayplots: bool = False,
+    detrendorder: int = 1,
+    debug: bool = False,
+) -> Tuple[Optional[float], Optional[float]]:
+    """
+    Check for autocorrelation peaks in a cross-correlation signal and fit a Gaussian to the sidelobe.
+
+    This function identifies peaks in the cross-correlation signal and, if a significant
+    sidelobe is detected (based on amplitude and lag thresholds), fits a Gaussian function
+    to estimate the sidelobe's time and amplitude.
 
     Parameters
     ----------
-    corrscale
-    thexcorr
-    delta
-    acampthresh
-    aclagthresh
-    displayplots
-    windowfunc
-    detrendorder
+    corrscale : NDArray
+        Array of time lags corresponding to the cross-correlation values.
+    thexcorr : NDArray
+        Array of cross-correlation values.
+    delta : float, optional
+        Minimum distance between peaks, default is 0.05.
+    acampthresh : float, optional
+        Amplitude threshold for detecting sidelobes, default is 0.1.
+    aclagthresh : float, optional
+        Lag threshold beyond which sidelobes are ignored, default is 10.0.
+    displayplots : bool, optional
+        If True, display the cross-correlation plot with detected peaks, default is False.
+    detrendorder : int, optional
+        Order of detrending to apply to the signal, default is 1.
+    debug : bool, optional
+        If True, print debug information, default is False.
 
     Returns
     -------
-    sidelobetime
-    sidelobeamp
+    Tuple[Optional[float], Optional[float]]
+        A tuple containing the estimated sidelobe time and amplitude if a valid sidelobe is found,
+        otherwise (None, None).
+
+    Notes
+    -----
+    - The function uses `peakdetect` to find peaks in the cross-correlation.
+    - A Gaussian fit is performed only if a peak is found beyond the zero-lag point and
+      satisfies the amplitude and lag thresholds.
+    - The fit is performed on a window around the detected sidelobe.
+
+    Examples
+    --------
+    >>> corrscale = np.linspace(0, 20, 100)
+    >>> thexcorr = np.exp(-0.5 * (corrscale - 5)**2 / 2) + 0.1 * np.random.rand(100)
+    >>> time, amp = check_autocorrelation(corrscale, thexcorr, delta=0.1, acampthresh=0.05)
+    >>> print(f"Sidelobe time: {time}, Amplitude: {amp}")
     """
+    if debug:
+        print("check_autocorrelation:")
+        print(f"delta: {delta}")
+        print(f"acampthresh: {acampthresh}")
+        print(f"aclagthresh: {aclagthresh}")
+        print(f"displayplots: {displayplots}")
     lookahead = 2
+    if displayplots:
+        print(f"check_autocorrelation: {displayplots=}")
+        plt.plot(corrscale, thexcorr)
+        plt.show()
     peaks = tide_fit.peakdetect(thexcorr, x_axis=corrscale, delta=delta, lookahead=lookahead)
     maxpeaks = np.asarray(peaks[0], dtype="float64")
     if len(peaks[0]) > 0:
+        if debug:
+            print(f"found {len(peaks[0])} peaks")
+            print(peaks)
         LGR.debug(peaks)
         zeropkindex = np.argmin(abs(maxpeaks[:, 0]))
         for i in range(zeropkindex + 1, maxpeaks.shape[0]):
@@ -155,35 +175,68 @@ def check_autocorrelation(
                     )
                     plt.show()
                 return sidelobetime, sidelobeamp
+    else:
+        if debug:
+            print("no peaks found")
     return None, None
 
 
 def shorttermcorr_1D(
-    data1,
-    data2,
-    sampletime,
-    windowtime,
-    samplestep=1,
-    detrendorder=0,
-    windowfunc="hamming",
-):
-    """Calculate short-term sliding-window correlation between two 1D arrays.
+    data1: NDArray,
+    data2: NDArray,
+    sampletime: float,
+    windowtime: float,
+    samplestep: int = 1,
+    detrendorder: int = 0,
+    windowfunc: str = "hamming",
+) -> Tuple[NDArray, NDArray, NDArray]:
+    """
+    Compute short-term cross-correlation between two 1D signals using sliding windows.
+
+    This function calculates the Pearson correlation coefficient between two signals
+    over short time windows, allowing for the analysis of time-varying correlations.
+    The correlation is computed for overlapping windows across the input data,
+    with optional detrending and windowing applied to each segment.
 
     Parameters
     ----------
-    data1
-    data2
-    sampletime
-    windowtime
-    samplestep
-    detrendorder
-    windowfunc
+    data1 : NDArray
+        First input signal (1D array).
+    data2 : NDArray
+        Second input signal (1D array). Must have the same length as `data1`.
+    sampletime : float
+        Time interval between consecutive samples in seconds.
+    windowtime : float
+        Length of the sliding window in seconds.
+    samplestep : int, optional
+        Step size (in samples) between consecutive windows. Default is 1.
+    detrendorder : int, optional
+        Order of detrending to apply before correlation. 0 means no detrending.
+        Default is 0.
+    windowfunc : str, optional
+        Window function to apply to each segment. Default is "hamming".
 
     Returns
     -------
-    times
-    corrpertime
-    ppertime
+    times : NDArray
+        Array of time values corresponding to the center of each window.
+    corrpertime : NDArray
+        Array of Pearson correlation coefficients for each window.
+    ppertime : NDArray
+        Array of p-values associated with the correlation coefficients.
+
+    Notes
+    -----
+    The function uses `tide_math.corrnormalize` for normalization and detrending
+    of signal segments, and `scipy.stats.pearsonr` for computing the correlation.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> data1 = np.random.randn(1000)
+    >>> data2 = np.random.randn(1000)
+    >>> times, corr, pvals = shorttermcorr_1D(data1, data2, 0.1, 1.0)
+    >>> print(f"Correlation at time {times[0]:.2f}: {corr[0]:.3f}")
     """
     windowsize = int(windowtime // sampletime)
     halfwindow = int((windowsize + 1) // 2)
@@ -201,10 +254,11 @@ def shorttermcorr_1D(
             detrendorder=detrendorder,
             windowfunc=windowfunc,
         )
-        thepcorr = sp.stats.pearsonr(dataseg1, dataseg2)
+        thepearsonresult = sp.stats.pearsonr(dataseg1, dataseg2)
+        thepcorrR, thepcorrp = thepearsonresult.statistic, thepearsonresult.pvalue
        times.append(i * sampletime)
-        corrpertime.append(thepcorr[0])
-        ppertime.append(thepcorr[1])
+        corrpertime.append(thepcorrR)
+        ppertime.append(thepcorrp)
     return (
         np.asarray(times, dtype="float64"),
         np.asarray(corrpertime, dtype="float64"),
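Note (illustrative, not part of the package diff): the hunk above swaps tuple indexing on the sp.stats.pearsonr return value for its named attributes. SciPy 1.9 and later return a PearsonRResult that supports both access styles, so the two forms are equivalent; a minimal sketch:

    import numpy as np
    from scipy import stats

    rng = np.random.default_rng(0)
    a, b = rng.standard_normal(100), rng.standard_normal(100)
    res = stats.pearsonr(a, b)
    r, p = res.statistic, res.pvalue    # named attributes, as in the 3.1.3 code
    r_old, p_old = res[0], res[1]       # tuple-style access, as in the 2.9.6 code
    assert (r, p) == (r_old, p_old)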
@@ -213,42 +267,85 @@
 
 
 def shorttermcorr_2D(
-    data1,
-    data2,
-    sampletime,
-    windowtime,
-    samplestep=1,
-    laglimits=None,
-    weighting="None",
-    zeropadding=0,
-    windowfunc="None",
-    detrendorder=0,
-    compress=False,
-    displayplots=False,
-):
-    """Calculate short-term sliding-window correlation between two 2D arrays.
+    data1: NDArray,
+    data2: NDArray,
+    sampletime: float,
+    windowtime: float,
+    samplestep: int = 1,
+    laglimits: Optional[Tuple[float, float]] = None,
+    weighting: str = "None",
+    zeropadding: int = 0,
+    windowfunc: str = "None",
+    detrendorder: int = 0,
+    compress: bool = False,
+    displayplots: bool = False,
+) -> Tuple[NDArray, NDArray, NDArray, NDArray, NDArray]:
+    """
+    Compute short-term cross-correlations between two 1D signals over sliding windows.
+
+    This function computes the cross-correlation between two input signals (`data1` and `data2`)
+    using a sliding window approach. For each window, the cross-correlation is computed and
+    the peak lag and correlation coefficient are extracted. The function supports detrending,
+    windowing, and various correlation weighting schemes.
 
     Parameters
     ----------
-    data1
-    data2
-    sampletime
-    windowtime
-    samplestep
-    laglimits
-    weighting
-    zeropadding
-    windowfunc
-    detrendorder
-    displayplots
+    data1 : NDArray
+        First input signal (1D array).
+    data2 : NDArray
+        Second input signal (1D array). Must be of the same length as `data1`.
+    sampletime : float
+        Sampling interval of the input signals in seconds.
+    windowtime : float
+        Length of the sliding window in seconds.
+    samplestep : int, optional
+        Step size (in samples) for the sliding window. Default is 1.
+    laglimits : Tuple[float, float], optional
+        Minimum and maximum lag limits (in seconds) for peak detection.
+        If None, defaults to ±windowtime/2.
+    weighting : str, optional
+        Type of weighting to apply during cross-correlation ('None', 'hamming', etc.).
+        Default is 'None'.
+    zeropadding : int, optional
+        Zero-padding factor for the FFT-based correlation. Default is 0.
+    windowfunc : str, optional
+        Type of window function to apply ('None', 'hamming', etc.). Default is 'None'.
+    detrendorder : int, optional
+        Order of detrending to apply before correlation (0 = no detrend, 1 = linear, etc.).
+        Default is 0.
+    compress : bool, optional
+        Whether to compress the correlation result. Default is False.
+    displayplots : bool, optional
+        Whether to display intermediate plots (e.g., correlation matrix). Default is False.
 
     Returns
     -------
-    times
-    xcorrpertime
-    Rvals
-    delayvals
-    valid
+    times : NDArray
+        Array of time values corresponding to the center of each window.
+    xcorrpertime : NDArray
+        Array of cross-correlation functions for each window.
+    Rvals : NDArray
+        Correlation coefficients for each window.
+    delayvals : NDArray
+        Estimated time delays (lags) for each window.
+    valid : NDArray
+        Binary array indicating whether the peak detection was successful (1) or failed (0).
+
+    Notes
+    -----
+    - The function uses `fastcorrelate` for efficient cross-correlation computation.
+    - Peak detection is performed using `tide_fit.findmaxlag_gauss`.
+    - If `displayplots` is True, an image of the cross-correlations is shown.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> t = np.linspace(0, 10, 1000)
+    >>> signal1 = np.sin(2 * np.pi * 0.5 * t)
+    >>> signal2 = np.sin(2 * np.pi * 0.5 * t + 0.1)
+    >>> times, xcorrs, Rvals, delays, valid = shorttermcorr_2D(
+    ...     signal1, signal2, sampletime=0.01, windowtime=1.0
+    ... )
     """
     windowsize = int(windowtime // sampletime)
     halfwindow = int((windowsize + 1) // 2)
@@ -339,74 +436,230 @@ def shorttermcorr_2D(
     )
 
 
-def calc_MI(x, y, bins=50):
-    """Calculate mutual information between two arrays.
+def calc_MI(x: NDArray, y: NDArray, bins: int = 50) -> float:
+    """
+    Calculate mutual information between two arrays.
+
+    Parameters
+    ----------
+    x : array-like
+        First array of data points
+    y : array-like
+        Second array of data points
+    bins : int, optional
+        Number of bins to use for histogram estimation, default is 50
+
+    Returns
+    -------
+    float
+        Mutual information between x and y
 
     Notes
     -----
-    From https://stackoverflow.com/questions/20491028/
-    optimal-way-to-compute-pairwise-mutual-information-using-numpy/
+    This implementation uses 2D histogram estimation followed by mutual information
+    calculation. The method is based on the approach from:
+    https://stackoverflow.com/questions/20491028/optimal-way-to-compute-pairwise-mutual-information-using-numpy/
     20505476#20505476
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> x = np.random.randn(1000)
+    >>> y = x + np.random.randn(1000) * 0.5
+    >>> mi = calc_MI(x, y)
+    >>> print(f"Mutual information: {mi:.3f}")
     """
     c_xy = np.histogram2d(x, y, bins)[0]
     mi = mutual_info_score(None, None, contingency=c_xy)
     return mi
 
 
+# @conditionaljit()
+def mutual_info_2d_fast(
+    x: NDArray[np.floating[Any]],
+    y: NDArray[np.floating[Any]],
+    bins: Tuple[NDArray, NDArray],
+    sigma: float = 1,
+    normalized: bool = True,
+    EPS: float = 1.0e-6,
+    debug: bool = False,
+) -> float:
+    """
+    Compute (normalized) mutual information between two 1D variates from a joint histogram.
+
+    Parameters
+    ----------
+    x : 1D NDArray[np.floating[Any]]
+        First variable.
+    y : 1D NDArray[np.floating[Any]]
+        Second variable.
+    bins : tuple of NDArray
+        Bin edges for the histogram. The first element corresponds to `x` and the second to `y`.
+    sigma : float, optional
+        Sigma for Gaussian smoothing of the joint histogram. Default is 1.
+    normalized : bool, optional
+        If True, compute normalized mutual information as defined in [1]_. Default is True.
+    EPS : float, optional
+        Small constant to avoid numerical errors in logarithms. Default is 1e-6.
+    debug : bool, optional
+        If True, print intermediate values for debugging. Default is False.
+
+    Returns
+    -------
+    float
+        The computed mutual information (or normalized mutual information if `normalized=True`).
+
+    Notes
+    -----
+    This function computes mutual information using a 2D histogram and Gaussian smoothing.
+    The normalization follows the approach described in [1]_.
+
+    References
+    ----------
+    .. [1] Colin Studholme, David John Hawkes, Derek L.G. Hill (1998).
+       "Normalized entropy measure for multimodality image alignment".
+       in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> x = np.random.randn(1000)
+    >>> y = np.random.randn(1000)
+    >>> bins = (np.linspace(-3, 3, 64), np.linspace(-3, 3, 64))
+    >>> mi = mutual_info_2d_fast(x, y, bins)
+    >>> print(mi)
+    """
+    xstart = bins[0][0]
+    xend = bins[0][-1]
+    ystart = bins[1][0]
+    yend = bins[1][-1]
+    numxbins = int(len(bins[0]) - 1)
+    numybins = int(len(bins[1]) - 1)
+    cuts = (x >= xstart) & (x < xend) & (y >= ystart) & (y < yend)
+    c = ((x[cuts] - xstart) / (xend - xstart) * numxbins).astype(np.int_)
+    c += ((y[cuts] - ystart) / (yend - ystart) * numybins).astype(np.int_) * numxbins
+    jh = np.bincount(c, minlength=numxbins * numybins).reshape(numxbins, numybins)
+
+    return proc_MI_histogram(jh, sigma=sigma, normalized=normalized, EPS=EPS, debug=debug)
+
+
 # @conditionaljit()
 def mutual_info_2d(
-    x, y, sigma=1, bins=(256, 256), fast=False, normalized=True, EPS=1.0e-6, debug=False
-):
-    """Compute (normalized) mutual information between two 1D variate from a joint histogram.
+    x: NDArray[np.floating[Any]],
+    y: NDArray[np.floating[Any]],
+    bins: Tuple[int, int],
+    sigma: float = 1,
+    normalized: bool = True,
+    EPS: float = 1.0e-6,
+    debug: bool = False,
+) -> float:
+    """
+    Compute (normalized) mutual information between two 1D variates from a joint histogram.
 
     Parameters
     ----------
-    x : 1D array
-        first variable
-    y : 1D array
-        second variable
+    x : 1D NDArray[np.floating[Any]]
+        First variable.
+    y : 1D NDArray[np.floating[Any]]
+        Second variable.
+    bins : tuple of int
+        Number of bins for the histogram. The first element is the number of bins for `x`
+        and the second for `y`.
     sigma : float, optional
-        Sigma for Gaussian smoothing of the joint histogram.
-        Default = 1.
-    bins : tuple, optional
-    fast : bool, optional
-    normalized : bool
-        If True, this will calculate the normalized mutual information from [1]_.
-        Default = False.
+        Sigma for Gaussian smoothing of the joint histogram. Default is 1.
+    normalized : bool, optional
+        If True, compute normalized mutual information as defined in [1]_. Default is True.
     EPS : float, optional
-        Default = 1.0e-6.
+        Small constant to avoid numerical errors in logarithms. Default is 1e-6.
+    debug : bool, optional
+        If True, print intermediate values for debugging. Default is False.
 
     Returns
     -------
-    nmi: float
-        the computed similarity measure
+    float
+        The computed mutual information (or normalized mutual information if `normalized=True`).
 
     Notes
     -----
-    From Ionnis Pappas
-    BBF added the precaching (fast) option
+    This function computes mutual information using a 2D histogram and Gaussian smoothing.
+    The normalization follows the approach described in [1]_.
 
     References
     ----------
     .. [1] Colin Studholme, David John Hawkes, Derek L.G. Hill (1998).
        "Normalized entropy measure for multimodality image alignment".
        in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> x = np.random.randn(1000)
+    >>> y = np.random.randn(1000)
+    >>> mi = mutual_info_2d(x, y)
+    >>> print(mi)
+    """
+    jh, xbins, ybins = np.histogram2d(x, y, bins=bins)
+    if debug:
+        print(f"{xbins} {ybins}")
+
+    return proc_MI_histogram(jh, sigma=sigma, normalized=normalized, EPS=EPS, debug=debug)
+
+
+def proc_MI_histogram(
+    jh: NDArray[np.floating[Any]],
+    sigma: float = 1,
+    normalized: bool = True,
+    EPS: float = 1.0e-6,
+    debug: bool = False,
+) -> float:
+    """
+    Compute the mutual information (MI) between two variables from a joint histogram.
+
+    This function calculates mutual information using the joint histogram of two variables,
+    applying Gaussian smoothing and computing entropy-based MI. It supports both normalized
+    and unnormalized versions of the mutual information.
+
+    Parameters
+    ----------
+    jh : ndarray of shape (m, n)
+        Joint histogram of two variables. Should be a 2D array of floating point values.
+    sigma : float, optional
+        Standard deviation for Gaussian smoothing of the joint histogram. Default is 1.0.
+    normalized : bool, optional
+        If True, returns normalized mutual information. If False, returns unnormalized
+        mutual information. Default is True.
+    EPS : float, optional
+        Small constant added to the histogram to avoid numerical issues in log computation.
+        Default is 1e-6.
+    debug : bool, optional
+        If True, prints intermediate values for debugging purposes. Default is False.
+
+    Returns
+    -------
+    float
+        The computed mutual information (MI) between the two variables. The value is
+        positive and indicates the amount of information shared between the variables.
+
+    Notes
+    -----
+    The function applies Gaussian smoothing to the joint histogram before computing
+    marginal and joint entropies. The mutual information is computed as:
+
+    .. math::
+        MI = \\frac{H(X) + H(Y)}{H(X,Y)} - 1
+
+    where :math:`H(X)`, :math:`H(Y)`, and :math:`H(X,Y)` are the marginal and joint entropies,
+    respectively. If `normalized=False`, the unnormalized MI is returned instead.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from scipy import ndimage
+    >>> jh = np.random.rand(10, 10)
+    >>> mi = proc_MI_histogram(jh, sigma=0.5, normalized=True)
+    >>> print(mi)
+    0.123456789
     """
-    if fast:
-        xstart = bins[0][0]
-        xend = bins[0][-1]
-        ystart = bins[1][0]
-        yend = bins[1][-1]
-        numxbins = len(bins[0]) - 1
-        numybins = len(bins[1]) - 1
-        cuts = (x >= xstart) & (x < xend) & (y >= ystart) & (y < yend)
-        c = ((x[cuts] - xstart) / (xend - xstart) * numxbins).astype(np.int_)
-        c += ((y[cuts] - ystart) / (yend - ystart) * numybins).astype(np.int_) * numxbins
-        jh = np.bincount(c, minlength=numxbins * numybins).reshape(numxbins, numybins)
-    else:
-        jh, xbins, ybins = np.histogram2d(x, y, bins=bins)
-        if debug:
-            print(f"{xbins} {ybins}")
 
     # smooth the jh with a gaussian filter of given sigma
     sp.ndimage.gaussian_filter(jh, sigma=sigma, mode="constant", output=jh)
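Note (illustrative, not part of the package diff): the proc_MI_histogram docstring above documents the normalized measure MI = (H(X) + H(Y)) / H(X, Y) - 1. A minimal, self-contained sketch of that computation from a smoothed joint histogram, following the documented formula; the function name mi_from_joint_histogram is hypothetical and not rapidtide's API:

    import numpy as np
    from scipy import ndimage

    def mi_from_joint_histogram(jh, sigma=1.0, normalized=True, EPS=1.0e-6):
        # Smooth the joint histogram, then convert it to a joint probability table.
        jh = np.asarray(jh, dtype=np.float64).copy()
        ndimage.gaussian_filter(jh, sigma=sigma, mode="constant", output=jh)
        jh = jh + EPS
        jh = jh / jh.sum()
        px = jh.sum(axis=1, keepdims=True)  # marginal p(x) (sum over columns)
        py = jh.sum(axis=0, keepdims=True)  # marginal p(y) (sum over rows)
        if normalized:
            # (H(X) + H(Y)) / H(X, Y) - 1, as in the docstring above
            return (np.sum(px * np.log(px)) + np.sum(py * np.log(py))) / np.sum(jh * np.log(jh)) - 1.0
        # Unnormalized MI: H(X) + H(Y) - H(X, Y)
        return np.sum(jh * np.log(jh)) - np.sum(px * np.log(px)) - np.sum(py * np.log(py))

    # Independent inputs should give a value near zero with either convention.
    rng = np.random.default_rng(0)
    jointhist = np.histogram2d(rng.standard_normal(5000), rng.standard_normal(5000), bins=32)[0]
    print(mi_from_joint_histogram(jointhist))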
@@ -441,65 +694,96 @@ def mutual_info_2d(
 
 # @conditionaljit
 def cross_mutual_info(
-    x,
-    y,
-    returnaxis=False,
-    negsteps=-1,
-    possteps=-1,
-    locs=None,
-    Fs=1.0,
-    norm=True,
-    madnorm=False,
-    windowfunc="None",
-    bins=-1,
-    prebin=True,
-    sigma=0.25,
-    fast=True,
-):
-    """Calculate cross-mutual information between two 1D arrays.
+    x: NDArray[np.floating[Any]],
+    y: NDArray[np.floating[Any]],
+    returnaxis: bool = False,
+    negsteps: int = -1,
+    possteps: int = -1,
+    locs: Optional[NDArray] = None,
+    Fs: float = 1.0,
+    norm: bool = True,
+    madnorm: bool = False,
+    windowfunc: str = "None",
+    bins: int = -1,
+    prebin: bool = True,
+    sigma: float = 0.25,
+    fast: bool = True,
+) -> Union[NDArray, Tuple[NDArray, NDArray, int]]:
+    """
+    Calculate cross-mutual information between two 1D arrays.
+
+    This function computes the cross-mutual information (MI) between two signals
+    `x` and `y` at various time lags or specified offsets. It supports normalization,
+    windowing, and histogram smoothing for robust estimation.
+
     Parameters
     ----------
-    x : 1D array
-        first variable
-    y : 1D array
-        second variable. The length of y must by >= the length of x
-    returnaxis : bool
-        set to True to return the time axis
-    negsteps: int
-    possteps: int
-    locs : list
-        a set of offsets at which to calculate the cross mutual information
-    Fs=1.0,
-    norm : bool
-        calculate normalized MI at each offset
-    madnorm : bool
-        set to True to normalize cross MI waveform by it's median average deviate
-    windowfunc : str
-        name of the window function to apply to input vectors prior to MI calculation
-    bins : int
-        number of bins in each dimension of the 2D histogram. Set to -1 to set automatically
-    prebin : bool
-        set to true to cache 2D histogram for all offsets
-    sigma : float
-        histogram smoothing kernel
-    fast: bool
-        apply speed optimizations
+    x : NDArray[np.floating[Any]]
+        First variable (signal).
+    y : NDArray[np.floating[Any]]
+        Second variable (signal). Must have length >= length of `x`.
+    returnaxis : bool, optional
+        If True, return the time axis along with the MI values. Default is False.
+    negsteps : int, optional
+        Number of negative time steps to compute MI for. If -1, uses default based on signal length.
+        Default is -1.
+    possteps : int, optional
+        Number of positive time steps to compute MI for. If -1, uses default based on signal length.
+        Default is -1.
+    locs : ndarray of int, optional
+        Specific time offsets at which to compute MI. If None, uses `negsteps` and `possteps`.
+        Default is None.
+    Fs : float, optional
+        Sampling frequency. Used when `returnaxis` is True. Default is 1.0.
+    norm : bool, optional
+        If True, normalize the MI values. Default is True.
+    madnorm : bool, optional
+        If True, normalize the MI waveform by its median absolute deviation (MAD).
+        Default is False.
+    windowfunc : str, optional
+        Name of the window function to apply to input signals before MI calculation.
+        Default is "None".
+    bins : int, optional
+        Number of bins for the 2D histogram. If -1, automatically determined.
+        Default is -1.
+    prebin : bool, optional
+        If True, precompute and cache the 2D histogram for all offsets.
+        Default is True.
+    sigma : float, optional
+        Standard deviation of the Gaussian smoothing kernel applied to the histogram.
+        Default is 0.25.
+    fast : bool, optional
+        If True, apply speed optimizations. Default is True.
 
     Returns
     -------
-    if returnaxis is True:
-        thexmi_x : 1D array
-            the set of offsets at which cross mutual information is calcuated
-        thexmi_y : 1D array
-            the set of cross mutual information values
-        len(thexmi_x): int
-            the number of cross mutual information values returned
-    else:
-        thexmi_y : 1D array
-            the set of cross mutual information values
+    ndarray or tuple of ndarray
+        If `returnaxis` is False:
+            The set of cross-mutual information values.
+        If `returnaxis` is True:
+            Tuple of (time_axis, mi_values, num_values), where:
+            - time_axis : ndarray of float
+                Time axis corresponding to the MI values.
+            - mi_values : ndarray of float
+                Cross-mutual information values.
+            - num_values : int
+                Number of MI values returned.
 
+    Notes
+    -----
+    - The function normalizes input signals using detrending and optional windowing.
+    - Cross-mutual information is computed using 2D histogram estimation and
+      mutual information calculation.
+    - If `prebin` is True, the 2D histogram is precomputed for efficiency.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> x = np.random.randn(100)
+    >>> y = np.random.randn(100)
+    >>> mi = cross_mutual_info(x, y)
+    >>> mi_axis, mi_vals, num = cross_mutual_info(x, y, returnaxis=True, Fs=10)
     """
-
     normx = tide_math.corrnormalize(x, detrendorder=1, windowfunc=windowfunc)
     normy = tide_math.corrnormalize(y, detrendorder=1, windowfunc=windowfunc)
 
@@ -538,35 +822,59 @@
         else:
             destloc += 1
         if i < 0:
-            thexmi_y[destloc] = mutual_info_2d(
-                normx[: i + len(normy)],
-                normy[-i:],
-                bins=bins2d,
-                normalized=norm,
-                fast=fast,
-                sigma=sigma,
-            )
+            if fast:
+                thexmi_y[destloc] = mutual_info_2d_fast(
+                    normx[: i + len(normy)],
+                    normy[-i:],
+                    bins2d,
+                    normalized=norm,
+                    sigma=sigma,
+                )
+            else:
+                thexmi_y[destloc] = mutual_info_2d(
+                    normx[: i + len(normy)],
+                    normy[-i:],
+                    bins2d,
+                    normalized=norm,
+                    sigma=sigma,
+                )
         elif i == 0:
-            thexmi_y[destloc] = mutual_info_2d(
-                normx,
-                normy,
-                bins=bins2d,
-                normalized=norm,
-                fast=fast,
-                sigma=sigma,
-            )
+            if fast:
+                thexmi_y[destloc] = mutual_info_2d_fast(
+                    normx,
+                    normy,
+                    bins2d,
+                    normalized=norm,
+                    sigma=sigma,
+                )
+            else:
+                thexmi_y[destloc] = mutual_info_2d(
+                    normx,
+                    normy,
+                    bins2d,
+                    normalized=norm,
+                    sigma=sigma,
+                )
         else:
-            thexmi_y[destloc] = mutual_info_2d(
-                normx[i:],
-                normy[: len(normy) - i],
-                bins=bins2d,
-                normalized=norm,
-                fast=fast,
-                sigma=sigma,
-            )
+            if fast:
+                thexmi_y[destloc] = mutual_info_2d_fast(
+                    normx[i:],
+                    normy[: len(normy) - i],
+                    bins2d,
+                    normalized=norm,
+                    sigma=sigma,
+                )
+            else:
+                thexmi_y[destloc] = mutual_info_2d(
+                    normx[i:],
+                    normy[: len(normy) - i],
+                    bins2d,
+                    normalized=norm,
+                    sigma=sigma,
+                )
 
     if madnorm:
-        thexmi_y = tide_math.madnormalize(thexmi_y)
+        thexmi_y = tide_math.madnormalize(thexmi_y)[0]
 
     if returnaxis:
         if locs is None:
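Note (illustrative, not part of the package diff): the hunks above split the old fast path out of mutual_info_2d. Judging only from the signatures shown in this file, a call that previously selected the cached-histogram path with fast=True now goes through mutual_info_2d_fast, which takes the bin edges positionally, while mutual_info_2d now takes bin counts. A rough before/after sketch, with hypothetical variable names (x, y, xedges, yedges):

    # rapidtide 2.9.6: one function; bins holds bin edges when fast=True
    mi = mutual_info_2d(x, y, sigma=0.25, bins=(xedges, yedges), fast=True, normalized=True)

    # rapidtide 3.1.3: explicit fast and slow entry points
    mi_fast = mutual_info_2d_fast(x, y, (xedges, yedges), sigma=0.25, normalized=True)
    mi_slow = mutual_info_2d(x, y, (256, 256), sigma=0.25, normalized=True)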
@@ -582,77 +890,94 @@ def cross_mutual_info(
582
890
  return thexmi_y
583
891
 
584
892
 
585
- def mutual_info_to_r(themi, d=1):
586
- """Convert mutual information to Pearson product-moment correlation."""
587
- return np.power(1.0 - np.exp(-2.0 * themi / d), -0.5)
893
+ def mutual_info_to_r(themi: float, d: int = 1) -> float:
894
+ """
895
+ Convert mutual information to Pearson product-moment correlation.
588
896
 
897
+ This function transforms mutual information values into Pearson correlation coefficients
898
+ using the relationship derived from the assumption of joint Gaussian distributions.
589
899
 
590
- def dtw_distance(s1, s2):
591
- # Dynamic time warping function written by GPT-4
592
- # Get the lengths of the two input sequences
593
- n, m = len(s1), len(s2)
900
+ Parameters
901
+ ----------
902
+ themi : float
903
+ Mutual information value (in nats) to be converted.
904
+ d : int, default=1
905
+ Dimensionality of the random variables. For single-dimensional variables, d=1.
906
+ For multi-dimensional variables, d represents the number of dimensions.
594
907
 
595
- # Initialize a (n+1) x (m+1) matrix with zeros
596
- DTW = np.zeros((n + 1, m + 1))
908
+ Returns
909
+ -------
910
+ float
911
+ Pearson product-moment correlation coefficient corresponding to the input
912
+ mutual information value. The result is in the range [0, 1].
597
913
 
598
- # Set the first row and first column of the matrix to infinity, since
599
- # the first element of each sequence cannot be aligned with an empty sequence
600
- DTW[1:, 0] = np.inf
601
- DTW[0, 1:] = np.inf
914
+ Notes
915
+ -----
916
+ The transformation is based on the formula:
917
+ r = (1 - exp(-2*MI/d))^(-1/2)
602
918
 
603
- # Compute the DTW distance by iteratively filling in the matrix
604
- for i in range(1, n + 1):
605
- for j in range(1, m + 1):
606
- # Compute the cost of aligning the i-th element of s1 with the j-th element of s2
607
- cost = abs(s1[i - 1] - s2[j - 1])
919
+ This approximation is valid under the assumption that the variables follow
920
+ a joint Gaussian distribution. For non-Gaussian distributions, the relationship
921
+ may not hold exactly.
608
922
 
609
- # Compute the minimum cost of aligning the first i-1 elements of s1 with the first j elements of s2,
610
- # the first i elements of s1 with the first j-1 elements of s2, and the first i-1 elements of s1
611
- # with the first j-1 elements of s2, and add this to the cost of aligning the i-th element of s1
612
- # with the j-th element of s2
613
- DTW[i, j] = cost + np.min([DTW[i - 1, j], DTW[i, j - 1], DTW[i - 1, j - 1]])
923
+ Examples
924
+ --------
925
+ >>> mutual_info_to_r(1.0)
926
+ 0.8416445342422313
614
927
 
615
- # Return the DTW distance between the two sequences, which is the value in the last cell of the matrix
616
- return DTW[n, m]
928
+ >>> mutual_info_to_r(2.0, d=2)
929
+ 0.9640275800758169
930
+ """
931
+ return np.power(1.0 - np.exp(-2.0 * themi / d), -0.5)
617
932
 
618
933
 
619
- def delayedcorr(data1, data2, delayval, timestep):
620
- """Calculate correlation between two 1D arrays, at specific delay.
934
+ def delayedcorr(
935
+ data1: NDArray, data2: NDArray, delayval: float, timestep: float
936
+ ) -> Tuple[float, float]:
937
+ return sp.stats.pearsonr(
938
+ data1, tide_resample.timeshift(data2, delayval / timestep, 30).statistic
939
+ )
621
940
 
622
- Parameters
623
- ----------
624
- data1
625
- data2
626
- delayval
627
- timestep
628
941
 
629
- Returns
630
- -------
631
- corr
942
+ def cepstraldelay(
943
+ data1: NDArray, data2: NDArray, timestep: float, displayplots: bool = True
944
+ ) -> float:
632
945
  """
633
- return sp.stats.pearsonr(data1, tide_resample.timeshift(data2, delayval / timestep, 30)[0])
946
+ Estimate the delay between two signals using Choudhary's cepstral analysis method.
634
947
 
635
-
636
- def cepstraldelay(data1, data2, timestep, displayplots=True):
637
- """
638
- Estimate delay between two signals using Choudhary's cepstral analysis method.
948
+ This function estimates the relative delay between data1 and data2 from their
+ complex cepstra, following the inter-sensor time delay approach of Choudhary,
+ Bahl & Kumar (2015), rather than by time-domain correlation.
639
951
 
640
952
  Parameters
641
953
  ----------
642
- data1
643
- data2
644
- timestep
645
- displayplots
954
+ data1 : NDArray
+ First input signal.
+ data2 : NDArray
+ Second input signal, whose delay relative to data1 is estimated.
+ timestep : float
+ Sample spacing of the input signals, in seconds.
+ displayplots : bool, optional
+ If True, display diagnostic plots of the cepstra. Default is True.
646
962
 
647
963
  Returns
648
964
  -------
649
- arr
965
+ float
+ The estimated delay between data1 and data2, in seconds.
650
967
 
651
- References
652
- ----------
653
- * Choudhary, H., Bahl, R. & Kumar, A.
654
- Inter-sensor Time Delay Estimation using cepstrum of sum and difference signals in
655
- underwater multipath environment. in 1-7 (IEEE, 2015). doi:10.1109/UT.2015.7108308
968
+ Notes
+ -----
+ The delay estimate is derived from the complex cepstra of the two signals
+ (computed with tide_math.complex_cepstrum) and is converted to seconds
+ using timestep.
+
+ References
+ ----------
+ * Choudhary, H., Bahl, R. & Kumar, A.
+ Inter-sensor Time Delay Estimation using cepstrum of sum and difference signals in
+ underwater multipath environment. in 1-7 (IEEE, 2015). doi:10.1109/UT.2015.7108308
656
981
  """
657
982
  ceps1, _ = tide_math.complex_cepstrum(data1)
658
983
  ceps2, _ = tide_math.complex_cepstrum(data2)
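The cepstral idea behind this function (see the Choudhary et al. reference in the removed docstring above) can be demonstrated without the package: a delayed, attenuated copy of a signal produces a peak in the real cepstrum at the delay. This is a toy sketch, not tide_math.complex_cepstrum.

import numpy as np

rng = np.random.default_rng(0)
fs = 100.0                                   # Hz
lag = 25                                     # true delay: 0.25 s
x = rng.standard_normal(2048)
y = x + 0.5 * np.roll(x, lag)                # signal plus delayed, attenuated copy

# real cepstrum: inverse FFT of the log magnitude spectrum
cepstrum = np.fft.ifft(np.log(np.abs(np.fft.fft(y)) + 1e-12)).real

search = cepstrum[5 : len(cepstrum) // 2]    # skip the very low quefrencies
print((np.argmax(search) + 5) / fs)          # approximately 0.25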
@@ -690,22 +1015,75 @@ def cepstraldelay(data1, data2, timestep, displayplots=True):
690
1015
 
691
1016
 
692
1017
  class AliasedCorrelator:
693
- """An aliased correlator.
1018
+ def __init__(self, hiressignal, hires_Fs, numsteps):
1019
+ """
1020
+ Initialize the object with high-resolution signal parameters.
694
1021
 
695
- Parameters
696
- ----------
697
- hiressignal : 1D array
698
- The unaliased waveform to match
699
- hires_Fs : float
700
- The sample rate of the unaliased waveform
701
- numsteps : int
702
- Number of distinct slice acquisition times within the TR.
703
- """
1022
+ Parameters
1023
+ ----------
1024
+ hiressignal : array-like
1025
+ The unaliased, high-resolution reference waveform to match.
1026
+ hires_Fs : float
1027
+ Sampling frequency of the high-resolution signal in Hz.
1028
+ numsteps : int
1029
+ Number of distinct slice acquisition times within the TR.
704
1030
 
705
- def __init__(self, hiressignal, hires_Fs, numsteps):
706
- self.hiressignal = tide_math.corrnormalize(hiressignal)
1031
+ Returns
1032
+ -------
1033
+ None
1034
+ This method initializes the object attributes and does not return any value.
1035
+
1036
+ Notes
1037
+ -----
1038
+ This constructor sets up the basic configuration for high-resolution signal processing
1039
+ by storing the sampling frequency and number of steps, then calls sethiressignal()
1040
+ to process the input signal.
1041
+
1042
+ Examples
1043
+ --------
1044
+ >>> correlator = AliasedCorrelator(hiressignal, hires_Fs=25.0, numsteps=10)
+ >>> correlator.hires_Fs
+ 25.0
+ >>> correlator.numsteps
+ 10
1049
+ """
707
1050
  self.hires_Fs = hires_Fs
708
1051
  self.numsteps = numsteps
1052
+ self.sethiressignal(hiressignal)
1053
+
1054
+ def sethiressignal(self, hiressignal):
1055
+ """
1056
+ Set high resolution signal and compute related parameters.
1057
+
1058
+ This method processes the high resolution signal by normalizing it and computing
1059
+ correlation-related parameters including correlation length and correlation x-axis.
1060
+
1061
+ Parameters
1062
+ ----------
1063
+ hiressignal : array-like
1064
+ High resolution signal data to be processed and normalized.
1065
+
1066
+ Returns
1067
+ -------
1068
+ None
1069
+ This method modifies the instance attributes in-place and does not return a value.
1070
+
1071
+ Notes
1072
+ -----
1073
+ The method performs correlation normalization using `tide_math.corrnormalize` and
1074
+ computes the correlation length as `len(self.hiressignal) * 2 + 1`. The correlation
1075
+ x-axis is computed based on the sampling frequency (`self.hires_Fs`) and the length
1076
+ of the high resolution signal.
1077
+
1078
+ Examples
1079
+ --------
1080
+ >>> obj.sethiressignal(hiressignal_data)
1081
+ >>> print(obj.corrlen)
1082
+ 1001
1083
+ >>> print(obj.corrx.shape)
1084
+ (1001,)
1085
+ """
1086
+ self.hiressignal = tide_math.corrnormalize(hiressignal)
709
1087
  self.corrlen = len(self.hiressignal) * 2 + 1
710
1088
  self.corrx = (
711
1089
  np.linspace(0.0, self.corrlen, num=self.corrlen) / self.hires_Fs
@@ -713,47 +1091,132 @@ class AliasedCorrelator:
713
1091
  )
714
1092
 
715
1093
  def getxaxis(self):
1094
+ """
1095
+ Return the time axis of the correlation function.
+
+ Returns
+ -------
+ corrx : 1D array
+ The time axis (in seconds) corresponding to the full correlation function.
+
+ Notes
+ -----
+ The axis is computed in sethiressignal() from the high-resolution sample
+ rate (`self.hires_Fs`) and the length of the high-resolution signal.
+
+ Examples
+ --------
+ >>> timeaxis = correlator.getxaxis()
1115
+ """
716
1116
  return self.corrx
717
1117
 
718
1118
  def apply(self, loressignal, offset, debug=False):
719
- """Apply correlator to aliased waveform.
1119
+ """
1120
+ Apply correlator to aliased waveform.
1121
+
720
1122
  NB: Assumes the highres frequency is an integral multiple of the lowres frequency
1123
+
721
1124
  Parameters
722
1125
  ----------
723
- loressignal: 1D array
1126
+ loressignal : 1D array
724
1127
  The aliased waveform to match
725
- offset: int
1128
+ offset : int
726
1129
  Integer offset to apply to the upsampled lowressignal (to account for slice time offset)
727
- debug: bool, optional
1130
+ debug : bool, optional
728
1131
  Whether to print diagnostic information
1132
+
729
1133
  Returns
730
1134
  -------
731
- corrfunc: 1D array
1135
+ corrfunc : 1D array
732
1136
  The full correlation function
1137
+
1138
+ Notes
1139
+ -----
1140
+ This function applies a correlator to an aliased waveform by:
1141
+ 1. Creating a zero-filled vector on the high-resolution time grid
1142
+ 2. Inserting the low-resolution signal at the specified offset
1143
+ 3. Computing the cross-correlation between the two signals
1144
+ 4. Normalizing the result by the square root of the number of steps
1145
+
1146
+ Examples
1147
+ --------
1148
+ >>> corrfunc = correlator.apply(loressignal, offset=5)
+ >>> # corrfunc holds the full correlation function on the high-resolution grid
733
1151
  """
734
1152
  if debug:
735
1153
  print(offset, self.numsteps)
736
- osvec = self.hiressignal * 0.0
1154
+ osvec = np.zeros_like(self.hiressignal)
737
1155
  osvec[offset :: self.numsteps] = loressignal[:]
738
- corrfunc = (
739
- tide_corr.fastcorrelate(tide_math.corrnormalize(osvec), self.hiressignal)
740
- * self.numsteps
1156
+ corrfunc = fastcorrelate(tide_math.corrnormalize(osvec), self.hiressignal) * np.sqrt(
1157
+ self.numsteps
741
1158
  )
742
1159
  return corrfunc
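The key step in apply() is scattering the low-rate samples onto the high-rate time grid at the slice offset before correlating. Here is a stand-alone sketch of that trick with plain NumPy in place of corrnormalize/fastcorrelate; the signal and values are illustrative only.

import numpy as np

numsteps = 10                                # hires rate = numsteps x lores rate
offset = 3                                   # slice-time offset, in hires samples
hires = np.sin(2.0 * np.pi * 0.01 * np.arange(1000))
lores = hires[offset::numsteps]              # the aliased waveform at the low rate

osvec = np.zeros_like(hires)
osvec[offset::numsteps] = lores              # sparse re-insertion on the hires grid

xcorr = np.correlate(osvec, hires, mode="full")
peaklag = int(np.argmax(xcorr)) - (len(hires) - 1)
print(peaklag)                               # 0: the offset was compensated exactly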
743
1160
 
744
1161
 
745
- def arbcorr(
746
- input1,
747
- Fs1,
748
- input2,
749
- Fs2,
750
- start1=0.0,
751
- start2=0.0,
752
- windowfunc="hamming",
753
- method="univariate",
754
- debug=False,
755
- ):
756
- """Calculate something."""
1162
+ def matchsamplerates(
1163
+ input1: NDArray,
1164
+ Fs1: float,
1165
+ input2: NDArray,
1166
+ Fs2: float,
1167
+ method: str = "univariate",
1168
+ debug: bool = False,
1169
+ ) -> Tuple[NDArray, NDArray, float]:
1170
+ """
1171
+ Match sampling rates of two input arrays by upsampling the lower sampling rate signal.
1172
+
1173
+ This function takes two input arrays with potentially different sampling rates and
1174
+ ensures they have the same sampling rate by upsampling the signal with the lower
1175
+ sampling rate to match the higher one. The function preserves the original data
1176
+ while adjusting the sampling rate for compatibility.
1177
+
1178
+ Parameters
1179
+ ----------
1180
+ input1 : NDArray
1181
+ First input array to be processed.
1182
+ Fs1 : float
1183
+ Sampling frequency of the first input array (Hz).
1184
+ input2 : NDArray
1185
+ Second input array to be processed.
1186
+ Fs2 : float
1187
+ Sampling frequency of the second input array (Hz).
1188
+ method : str, optional
1189
+ Resampling method to use, by default "univariate".
1190
+ See `tide_resample.upsample` for available methods.
1191
+ debug : bool, optional
1192
+ Enable debug output, by default False.
1193
+
1194
+ Returns
1195
+ -------
1196
+ Tuple[NDArray, NDArray, float]
1197
+ Tuple containing:
1198
+ - matchedinput1: First input, upsampled to the common rate if Fs1 < Fs2, otherwise unchanged
+ - matchedinput2: Second input, upsampled to the common rate if Fs2 < Fs1, otherwise unchanged
+ - corrFs: The common (higher) sampling frequency of the two outputs
1201
+
1202
+ Notes
1203
+ -----
1204
+ - If sampling rates are equal, no upsampling is performed
1205
+ - The function always upsamples to the higher sampling rate
1206
+ - The upsampling is performed using the `tide_resample.upsample` function
1207
+ - Both outputs are expressed at the common sampling rate corrFs; only the lower-rate input is resampled
1208
+
1209
+ Examples
1210
+ --------
1211
+ >>> import numpy as np
1212
+ >>> input1 = np.array([1, 2, 3, 4])
1213
+ >>> input2 = np.array([5, 6, 7])
1214
+ >>> Fs1 = 10.0
1215
+ >>> Fs2 = 5.0
1216
+ >>> matched1, matched2, common_fs = matchsamplerates(input1, Fs1, input2, Fs2)
1217
+ >>> print(common_fs)
1218
+ 10.0
1219
+ """
757
1220
  if Fs1 > Fs2:
758
1221
  corrFs = Fs1
759
1222
  matchedinput1 = input1
@@ -763,9 +1226,90 @@ def arbcorr(
763
1226
  matchedinput1 = tide_resample.upsample(input1, Fs1, corrFs, method=method, debug=debug)
764
1227
  matchedinput2 = input2
765
1228
  else:
766
- corrFs = Fs1
1229
+ corrFs = Fs2
767
1230
  matchedinput1 = input1
768
1231
  matchedinput2 = input2
1232
+ return matchedinput1, matchedinput2, corrFs
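A compact illustration of the behaviour documented above, with np.interp standing in for tide_resample.upsample (the resampler and the helper name are assumptions, not the package code): only the lower-rate series is resampled, onto the higher of the two rates.

import numpy as np

def match_rates_sketch(input1, Fs1, input2, Fs2):
    corrFs = max(Fs1, Fs2)

    def up(data, fs):
        if fs == corrFs:
            return data                       # already at the common rate
        told = np.arange(len(data)) / fs
        tnew = np.arange(0.0, told[-1] + 0.5 / corrFs, 1.0 / corrFs)
        return np.interp(tnew, told, data)    # linear-interpolation upsample

    return up(input1, Fs1), up(input2, Fs2), corrFs

fast = np.arange(8, dtype=float)              # sampled at 4 Hz
slow = np.arange(4, dtype=float)              # sampled at 2 Hz
m1, m2, fs = match_rates_sketch(fast, 4.0, slow, 2.0)
print(fs, len(m1), len(m2))                   # 4.0 8 7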
1233
+
1234
+
1235
+ def arbcorr(
1236
+ input1: NDArray,
1237
+ Fs1: float,
1238
+ input2: NDArray,
1239
+ Fs2: float,
1240
+ start1: float = 0.0,
1241
+ start2: float = 0.0,
1242
+ windowfunc: str = "hamming",
1243
+ method: str = "univariate",
1244
+ debug: bool = False,
1245
+ ) -> Tuple[NDArray, NDArray, float, int]:
1246
+ """
1247
+ Compute the cross-correlation between two signals with arbitrary sampling rates.
1248
+
1249
+ This function performs cross-correlation between two input signals after
1250
+ matching their sampling rates. It applies normalization and uses FFT-based
1251
+ convolution for efficient computation. The result includes the time lag axis,
1252
+ cross-correlation values, the matched sampling frequency, and the index of
1253
+ the zero-lag point.
1254
+
1255
+ Parameters
1256
+ ----------
1257
+ input1 : NDArray
1258
+ First input signal array.
1259
+ Fs1 : float
1260
+ Sampling frequency of the first signal (Hz).
1261
+ input2 : NDArray
1262
+ Second input signal array.
1263
+ Fs2 : float
1264
+ Sampling frequency of the second signal (Hz).
1265
+ start1 : float, optional
1266
+ Start time of the first signal (default is 0.0).
1267
+ start2 : float, optional
1268
+ Start time of the second signal (default is 0.0).
1269
+ windowfunc : str, optional
1270
+ Window function used for normalization (default is "hamming").
1271
+ method : str, optional
1272
+ Method used for matching sampling rates (default is "univariate").
1273
+ debug : bool, optional
1274
+ If True, enables debug logging (default is False).
1275
+
1276
+ Returns
1277
+ -------
1278
+ tuple
1279
+ A tuple containing:
1280
+ - thexcorr_x : NDArray
1281
+ Time lag axis for the cross-correlation (seconds).
1282
+ - thexcorr_y : NDArray
1283
+ Cross-correlation values.
1284
+ - corrFs : float
1285
+ Matched sampling frequency used for the computation (Hz).
1286
+ - zeroloc : int
1287
+ Index corresponding to the zero-lag point in the cross-correlation.
1288
+
1289
+ Notes
1290
+ -----
1291
+ - The function upsamples the signals to the higher of the two sampling rates.
1292
+ - Normalization is applied using a detrend order of 1 and the specified window function.
1293
+ - The cross-correlation is computed using FFT convolution for efficiency.
1294
+ - The zero-lag point is determined as the index of the minimum absolute value in the time axis.
1295
+
1296
+ Examples
1297
+ --------
1298
+ >>> import numpy as np
1299
+ >>> signal1 = np.random.randn(1000)
1300
+ >>> signal2 = np.random.randn(1000)
1301
+ >>> lags, corr_vals, fs, zero_idx = arbcorr(signal1, 10.0, signal2, 10.0)
1302
+ >>> print(f"Zero-lag index: {zero_idx}")
1303
+ """
1304
+ # upsample to the higher frequency of the two
1305
+ matchedinput1, matchedinput2, corrFs = matchsamplerates(
1306
+ input1,
1307
+ Fs1,
1308
+ input2,
1309
+ Fs2,
1310
+ method=method,
1311
+ debug=debug,
1312
+ )
769
1313
  norm1 = tide_math.corrnormalize(matchedinput1, detrendorder=1, windowfunc=windowfunc)
770
1314
  norm2 = tide_math.corrnormalize(matchedinput2, detrendorder=1, windowfunc=windowfunc)
771
1315
  thexcorr_y = signal.fftconvolve(norm1, norm2[::-1], mode="full")
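The lag axis and zero-lag index that arbcorr derives from this full-mode correlation (the remainder of the function is elided in this hunk) follow the usual bookkeeping sketched below; the start1/start2 correction applied by the real function is not reproduced here.

import numpy as np
from scipy import signal

corrFs = 10.0
a = np.random.randn(200)
b = np.roll(a, 15)                            # b lags a by 15 samples

xcorr = signal.fftconvolve(a, b[::-1], mode="full")
lags = (np.arange(len(xcorr)) - (len(b) - 1)) / corrFs   # seconds, one per point
zeroloc = int(np.argmin(np.abs(lags)))        # index of the zero-lag sample
print(zeroloc, lags[int(np.argmax(xcorr))])   # 199, about -1.5 with this sign convention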
@@ -784,13 +1328,66 @@ def arbcorr(
784
1328
 
785
1329
 
786
1330
  def faststcorrelate(
787
- input1, input2, windowtype="hann", nperseg=32, weighting="None", displayplots=False
788
- ):
789
- """Perform correlation between short-time Fourier transformed arrays."""
1331
+ input1: NDArray,
1332
+ input2: NDArray,
1333
+ windowtype: str = "hann",
1334
+ nperseg: int = 32,
1335
+ weighting: str = "None",
1336
+ displayplots: bool = False,
1337
+ ) -> Tuple[NDArray, NDArray, NDArray]:
1338
+ """
1339
+ Perform correlation between short-time Fourier transformed arrays.
1340
+
1341
+ This function computes the short-time cross-correlation between two input signals
1342
+ using their short-time Fourier transforms (STFTs). It applies a windowing function
1343
+ to each signal, computes the STFT, and then performs correlation in the frequency
1344
+ domain before inverse transforming back to the time domain. The result is normalized
1345
+ by the auto-correlation of each signal.
1346
+
1347
+ Parameters
1348
+ ----------
1349
+ input1 : ndarray
1350
+ First input signal array.
1351
+ input2 : ndarray
1352
+ Second input signal array.
1353
+ windowtype : str, optional
1354
+ Type of window to apply. Default is 'hann'.
1355
+ nperseg : int, optional
1356
+ Length of each segment for STFT. Default is 32.
1357
+ weighting : str, optional
1358
+ Weighting method for the STFT. Default is 'None'.
1359
+ displayplots : bool, optional
1360
+ If True, display plots (not implemented in current version). Default is False.
1361
+
1362
+ Returns
1363
+ -------
1364
+ corrtimes : ndarray
1365
+ Time shifts corresponding to the correlation results.
1366
+ times : ndarray
1367
+ Time indices of the STFT.
1368
+ stcorr : ndarray
1369
+ Short-time cross-correlation values.
1370
+
1371
+ Notes
1372
+ -----
1373
+ The function uses `scipy.signal.stft` to compute the short-time Fourier transform
1374
+ of both input signals. The correlation is computed in the frequency domain and
1375
+ normalized by the square root of the auto-correlation of each signal.
1376
+
1377
+ Examples
1378
+ --------
1379
+ >>> import numpy as np
1380
+ >>> from scipy import signal
1381
+ >>> t = np.linspace(0, 1, 100)
1382
+ >>> x1 = np.sin(2 * np.pi * 5 * t)
1383
+ >>> x2 = np.sin(2 * np.pi * 5 * t + 0.1)
1384
+ >>> corrtimes, times, corr = faststcorrelate(x1, x2)
1385
+ >>> print(corr.shape)
1386
+ (32, 100)
1387
+ """
790
1388
  nfft = nperseg
791
1389
  noverlap = nperseg - 1
792
1390
  onesided = False
793
- boundary = "even"
794
1391
  freqs, times, thestft1 = signal.stft(
795
1392
  input1,
796
1393
  fs=1.0,
@@ -800,7 +1397,7 @@ def faststcorrelate(
800
1397
  nfft=nfft,
801
1398
  detrend="linear",
802
1399
  return_onesided=onesided,
803
- boundary=boundary,
1400
+ boundary="even",
804
1401
  padded=True,
805
1402
  axis=-1,
806
1403
  )
@@ -814,7 +1411,7 @@ def faststcorrelate(
814
1411
  nfft=nfft,
815
1412
  detrend="linear",
816
1413
  return_onesided=onesided,
817
- boundary=boundary,
1414
+ boundary="even",
818
1415
  padded=True,
819
1416
  axis=-1,
820
1417
  )
@@ -840,7 +1437,40 @@ def faststcorrelate(
840
1437
  return corrtimes, times, stcorr
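For orientation, the frame-by-frame idea can be reproduced with scipy alone: take the STFT of both signals, form the per-frame cross-spectrum, and inverse-transform each column. This sketch uses scipy's default STFT settings rather than the nperseg-1 overlap, linear detrend, and even boundary used above, and its normalization is only correlation-like; it is not the package implementation.

import numpy as np
from scipy import signal

fs = 10.0
t = np.arange(0, 30, 1.0 / fs)
x1 = np.sin(2.0 * np.pi * 0.4 * t)
x2 = np.sin(2.0 * np.pi * 0.4 * (t - 0.5))    # x1 delayed by 0.5 s

nperseg = 32
_, frames, X1 = signal.stft(x1, fs=fs, nperseg=nperseg, return_onesided=False)
_, frames, X2 = signal.stft(x2, fs=fs, nperseg=nperseg, return_onesided=False)

stcorr = np.fft.ifft(X1 * np.conj(X2), axis=0).real   # one correlation per frame
norm = np.sqrt(np.sum(np.abs(X1) ** 2, axis=0) * np.sum(np.abs(X2) ** 2, axis=0))
stcorr = stcorr / (norm + 1e-30)
print(stcorr.shape)                           # (nperseg lags, number of frames)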
841
1438
 
842
1439
 
843
- def primefacs(thelen):
1440
+ def primefacs(thelen: int) -> list:
1441
+ """
1442
+ Compute the prime factorization of a given integer.
1443
+
1444
+ Parameters
1445
+ ----------
1446
+ thelen : int
1447
+ The positive integer to factorize. Must be greater than 0.
1448
+
1449
+ Returns
1450
+ -------
1451
+ list
1452
+ A list of prime factors of `thelen`, sorted in ascending order.
1453
+ Each factor appears as many times as its multiplicity in the
1454
+ prime factorization.
1455
+
1456
+ Notes
1457
+ -----
1458
+ This function uses the trial division algorithm to find prime factors.
+ It starts with the smallest prime (2) and tests successively larger
+ integers up to the square root of the remaining quotient.
1461
+ The final remaining number (if greater than 1) is also a prime factor.
1462
+
1463
+ Examples
1464
+ --------
1465
+ >>> primefacs(12)
1466
+ [2, 2, 3]
1467
+
1468
+ >>> primefacs(17)
1469
+ [17]
1470
+
1471
+ >>> primefacs(100)
1472
+ [2, 2, 5, 5]
1473
+ """
844
1474
  i = 2
845
1475
  factors = []
846
1476
  while i * i <= thelen:
@@ -853,40 +1483,62 @@ def primefacs(thelen):
853
1483
  return factors
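One plausible use of such a factorization in this context, offered here only as an illustration, is checking whether a candidate padded length is made of small primes (and therefore FFT-friendly); the helper below simply repeats the trial-division loop shown above.

def trial_division(n: int) -> list:
    # same trial-division scheme as primefacs() above
    i, factors = 2, []
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            factors.append(i)
    if n > 1:
        factors.append(n)
    return factors

print(trial_division(360))                    # [2, 2, 2, 3, 3, 5]
print(max(trial_division(4800)) <= 7)         # True: only small prime factors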
854
1484
 
855
1485
 
856
- def optfftlen(thelen):
857
- return thelen
858
-
859
-
860
1486
  def fastcorrelate(
861
- input1,
862
- input2,
863
- usefft=True,
864
- zeropadding=0,
865
- weighting="None",
866
- compress=False,
867
- displayplots=False,
868
- debug=False,
869
- ):
870
- """Perform a fast correlation between two arrays.
1487
+ input1: NDArray,
1488
+ input2: NDArray,
1489
+ usefft: bool = True,
1490
+ zeropadding: int = 0,
1491
+ weighting: str = "None",
1492
+ compress: bool = False,
1493
+ displayplots: bool = False,
1494
+ debug: bool = False,
1495
+ ) -> NDArray:
1496
+ """
1497
+ Perform a fast correlation between two arrays.
1498
+
1499
+ This function computes the cross-correlation of two input arrays, with options
1500
+ for using FFT-based convolution or direct correlation, as well as padding and
1501
+ weighting schemes.
871
1502
 
872
1503
  Parameters
873
1504
  ----------
874
- input1
875
- input2
876
- usefft
877
- zeropadding
878
- weighting
879
- compress
880
- displayplots
881
- debug
1505
+ input1 : ndarray
1506
+ First input array to correlate.
1507
+ input2 : ndarray
1508
+ Second input array to correlate.
1509
+ usefft : bool, optional
1510
+ If True, use FFT-based convolution for faster computation. Default is True.
1511
+ zeropadding : int, optional
1512
+ Zero-padding length. If 0, no padding is applied. If negative, automatic
1513
+ padding is applied. If positive, explicit padding is applied. Default is 0.
1514
+ weighting : str, optional
1515
+ Type of weighting to apply. If "None", no weighting is applied. Default is "None".
1516
+ compress : bool, optional
1517
+ If True and `weighting` is not "None", compress the result. Default is False.
1518
+ displayplots : bool, optional
1519
+ If True, display plots of padded inputs and correlation result. Default is False.
1520
+ debug : bool, optional
1521
+ If True, enable debug output. Default is False.
882
1522
 
883
1523
  Returns
884
1524
  -------
885
- corr
1525
+ ndarray
1526
+ The cross-correlation of `input1` and `input2`. The length of the output is
1527
+ `len(input1) + len(input2) - 1`.
886
1528
 
887
1529
  Notes
888
1530
  -----
889
- From http://stackoverflow.com/questions/12323959/fast-cross-correlation-method-in-python.
1531
+ This implementation is based on the method described at:
1532
+ http://stackoverflow.com/questions/12323959/fast-cross-correlation-method-in-python
1533
+
1534
+ Examples
1535
+ --------
1536
+ >>> import numpy as np
1537
+ >>> a = np.array([1, 2, 3])
1538
+ >>> b = np.array([0, 1, 0])
1539
+ >>> result = fastcorrelate(a, b)
1540
+ >>> print(result)
1541
+ [0. 1. 2. 3. 0.]
890
1542
  """
891
1543
  len1 = len(input1)
892
1544
  len2 = len(input2)
@@ -957,17 +1609,36 @@ def fastcorrelate(
957
1609
  return np.correlate(paddedinput1, paddedinput2, mode="full")
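The two paths documented above (direct np.correlate and FFT-based convolution of one input with the reverse of the other) produce the same full cross-correlation; a quick, package-independent check of that equivalence:

import numpy as np
from scipy import signal

a = np.random.randn(128)
b = np.random.randn(64)

direct = np.correlate(a, b, mode="full")
viafft = signal.fftconvolve(a, b[::-1], mode="full")

print(len(direct))                            # len(a) + len(b) - 1 == 191
print(np.allclose(direct, viafft))            # True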
958
1610
 
959
1611
 
960
- def _centered(arr, newsize):
961
- """Return the center newsize portion of the array.
1612
+ def _centered(arr: NDArray, newsize: Union[int, NDArray]) -> NDArray:
1613
+ """
1614
+ Extract a centered subset of an array.
962
1615
 
963
1616
  Parameters
964
1617
  ----------
965
- arr
966
- newsize
1618
+ arr : array_like
1619
+ Input array from which to extract the centered subset.
1620
+ newsize : int or array_like
1621
+ The size of the output array. If int, the same size is used for all dimensions.
1622
+ If array_like, specifies the size for each dimension.
967
1623
 
968
1624
  Returns
969
1625
  -------
970
- arr
1626
+ ndarray
1627
+ Centered subset of the input array with the specified size.
1628
+
1629
+ Notes
1630
+ -----
1631
+ The function extracts a subset from the center of the input array by slicing;
+ `newsize` should not exceed the current size in any dimension, since no zero
+ padding is performed.
1634
+
1635
+ Examples
1636
+ --------
1637
+ >>> import numpy as np
1638
+ >>> arr = np.arange(24).reshape(4, 6)
1639
+ >>> _centered(arr, (2, 3))
1640
+ array([[ 7, 8, 9],
1641
+ [13, 14, 15]])
971
1642
  """
972
1643
  newsize = np.asarray(newsize)
973
1644
  currsize = np.array(arr.shape)
@@ -977,22 +1648,36 @@ def _centered(arr, newsize):
977
1648
  return arr[tuple(myslice)]
978
1649
 
979
1650
 
980
- def _check_valid_mode_shapes(shape1, shape2):
981
- """Check that two shapes are 'valid' with respect to one another.
982
-
983
- Specifically, this checks that each item in one tuple is larger than or
984
- equal to corresponding item in another tuple.
1651
+ def _check_valid_mode_shapes(shape1: Tuple, shape2: Tuple) -> None:
1652
+ """
1653
+ Check that shape1 is valid for 'valid' mode convolution with shape2.
985
1654
 
986
1655
  Parameters
987
1656
  ----------
988
- shape1
989
- shape2
990
-
991
- Raises
992
- ------
993
- ValueError
994
- If at least one item in the first shape is not larger than or equal to
995
- the corresponding item in the second one.
1657
+ shape1 : Tuple
1658
+ First shape tuple to compare
1659
+ shape2 : Tuple
1660
+ Second shape tuple to compare
1661
+
1662
+ Returns
1663
+ -------
1664
+ None
1665
+ This function does not return anything but raises ValueError if condition is not met
1666
+
1667
+ Notes
1668
+ -----
1669
+ This function is used to validate that the first shape has at least as many
1670
+ elements as the second shape in every dimension, which is required for
1671
+ 'valid' mode convolution operations.
1672
+
1673
+ Examples
1674
+ --------
1675
+ >>> _check_valid_mode_shapes((10, 10), (5, 5))
1676
+ >>> _check_valid_mode_shapes((10, 10), (10, 5))
1677
+ >>> _check_valid_mode_shapes((5, 5), (10, 5))
1678
+ Traceback (most recent call last):
1679
+ ...
1680
+ ValueError: in1 should have at least as many items as in2 in every dimension for 'valid' mode.
996
1681
  """
997
1682
  for d1, d2 in zip(shape1, shape2):
998
1683
  if not d1 >= d2:
@@ -1003,22 +1688,29 @@ def _check_valid_mode_shapes(shape1, shape2):
1003
1688
 
1004
1689
 
1005
1690
  def convolve_weighted_fft(
1006
- in1, in2, mode="full", weighting="None", compress=False, displayplots=False
1007
- ):
1008
- """Convolve two N-dimensional arrays using FFT.
1691
+ in1: NDArray[np.floating[Any]],
1692
+ in2: NDArray[np.floating[Any]],
1693
+ mode: str = "full",
1694
+ weighting: str = "None",
1695
+ compress: bool = False,
1696
+ displayplots: bool = False,
1697
+ ) -> NDArray[np.floating[Any]]:
1698
+ """
1699
+ Convolve two N-dimensional arrays using FFT with optional weighting.
1009
1700
 
1010
1701
  Convolve `in1` and `in2` using the fast Fourier transform method, with
1011
- the output size determined by the `mode` argument.
1012
- This is generally much faster than `convolve` for large arrays (n > ~500),
1013
- but can be slower when only a few output values are needed, and can only
1014
- output float arrays (int or object array inputs will be cast to float).
1702
+ the output size determined by the `mode` argument. This is generally much
1703
+ faster than `convolve` for large arrays (n > ~500), but can be slower when
1704
+ only a few output values are needed. The function supports both real and
1705
+ complex inputs, and allows for optional weighting and compression of the
1706
+ FFT operations.
1015
1707
 
1016
1708
  Parameters
1017
1709
  ----------
1018
- in1 : array_like
1019
- First input.
1020
- in2 : array_like
1021
- Second input. Should have the same number of dimensions as `in1`;
1710
+ in1 : NDArray[np.floating[Any]]
1711
+ First input array.
1712
+ in2 : NDArray[np.floating[Any]]
1713
+ Second input array. Should have the same number of dimensions as `in1`;
1022
1714
  if sizes of `in1` and `in2` are not equal then `in1` has to be the
1023
1715
  larger array.
1024
1716
  mode : str {'full', 'valid', 'same'}, optional
@@ -1034,19 +1726,45 @@ def convolve_weighted_fft(
1034
1726
  ``same``
1035
1727
  The output is the same size as `in1`, centered
1036
1728
  with respect to the 'full' output.
1729
+ weighting : str, optional
1730
+ Type of weighting to apply during convolution. Default is "None".
1731
+ Other recognized options are 'liang', 'eckart', 'phat', and 'regressor',
+ as implemented in `gccproduct`.
1733
+ compress : bool, optional
1734
+ If True, compress the FFT data during computation. Default is False.
1735
+ displayplots : bool, optional
1736
+ If True, display intermediate plots during computation. Default is False.
1037
1737
 
1038
1738
  Returns
1039
1739
  -------
1040
- out : array
1740
+ out : NDArray[np.floating[Any]]
1041
1741
  An N-dimensional array containing a subset of the discrete linear
1042
- convolution of `in1` with `in2`.
1043
- """
1044
- in1 = np.asarray(in1)
1045
- in2 = np.asarray(in2)
1742
+ convolution of `in1` with `in2`. The shape of the output depends on
1743
+ the `mode` parameter.
1046
1744
 
1047
- if np.isscalar(in1) and np.isscalar(in2): # scalar inputs
1048
- return in1 * in2
1049
- elif not in1.ndim == in2.ndim:
1745
+ Notes
1746
+ -----
1747
+ - This function uses real FFT (`rfftn`) for real inputs and standard FFT
1748
+ (`fftpack.fftn`) for complex inputs.
1749
+ - The convolution is computed in the frequency domain using the product
1750
+ of FFTs of the inputs.
1751
+ - For real inputs, the result is scaled to preserve the maximum amplitude.
1752
+ - The `gccproduct` function is used internally to compute the product
1753
+ of the FFTs with optional weighting.
1754
+
1755
+ Examples
1756
+ --------
1757
+ >>> import numpy as np
1758
+ >>> a = np.array([[1, 2], [3, 4]])
1759
+ >>> b = np.array([[1, 0], [0, 1]])
1760
+ >>> result = convolve_weighted_fft(a, b)
1761
+ >>> result.shape
+ (3, 3)
1764
+ """
1765
+ # if np.isscalar(in1) and np.isscalar(in2): # scalar inputs
1766
+ # return in1 * in2
1767
+ if not in1.ndim == in2.ndim:
1050
1768
  raise ValueError("in1 and in2 should have the same rank")
1051
1769
  elif in1.size == 0 or in2.size == 0: # empty arrays
1052
1770
  return np.array([])
@@ -1090,27 +1808,73 @@ def convolve_weighted_fft(
1090
1808
  # scale to preserve the maximum
1091
1809
 
1092
1810
  if mode == "full":
1093
- return ret
1811
+ retval = ret
1094
1812
  elif mode == "same":
1095
- return _centered(ret, s1)
1813
+ retval = _centered(ret, s1)
1096
1814
  elif mode == "valid":
1097
- return _centered(ret, s1 - s2 + 1)
1815
+ retval = _centered(ret, s1 - s2 + 1)
1816
+
1817
+ return retval
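The 'full', 'same', and 'valid' output sizes handled above behave as in ordinary FFT convolution; a short demonstration using scipy directly (not the weighted variant defined here):

import numpy as np
from scipy import signal

in1 = np.random.randn(6, 6)
in2 = np.random.randn(3, 3)

full = signal.fftconvolve(in1, in2, mode="full")    # (6 + 3 - 1) per axis -> (8, 8)
same = signal.fftconvolve(in1, in2, mode="same")    # centered crop to in1's shape
valid = signal.fftconvolve(in1, in2, mode="valid")  # (6 - 3 + 1) per axis -> (4, 4)
print(full.shape, same.shape, valid.shape)          # (8, 8) (6, 6) (4, 4)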
1818
+
1098
1819
 
1820
+ def gccproduct(
1821
+ fft1: NDArray,
1822
+ fft2: NDArray,
1823
+ weighting: str,
1824
+ threshfrac: float = 0.1,
1825
+ compress: bool = False,
1826
+ displayplots: bool = False,
1827
+ ) -> NDArray:
1828
+ """
1829
+ Compute the generalized cross-correlation (GCC) product with optional weighting.
1099
1830
 
1100
- def gccproduct(fft1, fft2, weighting, threshfrac=0.1, compress=False, displayplots=False):
1101
- """Calculate product for generalized crosscorrelation.
1831
+ This function computes the GCC product of two FFT arrays, applying a specified
1832
+ weighting scheme to enhance correlation performance. It supports several weighting
1833
+ methods including 'liang', 'eckart', 'phat', and 'regressor'. The result can be
1834
+ thresholded and optionally compressed to improve visualization and reduce noise.
1102
1835
 
1103
1836
  Parameters
1104
1837
  ----------
1105
- fft1
1106
- fft2
1107
- weighting
1108
- threshfrac
1109
- displayplots
1838
+ fft1 : NDArray
1839
+ First FFT array (complex-valued).
1840
+ fft2 : NDArray
1841
+ Second FFT array (complex-valued).
1842
+ weighting : str
1843
+ Weighting method to apply. Options are:
1844
+ - 'liang': Liang weighting
1845
+ - 'eckart': Eckart weighting
1846
+ - 'phat': PHAT (Phase Transform) weighting
1847
+ - 'regressor': Regressor-based weighting (uses fft2 as reference)
1848
+ - 'None': No weighting applied.
1849
+ threshfrac : float, optional
1850
+ Threshold fraction used to determine the minimum value for output masking.
1851
+ Default is 0.1.
1852
+ compress : bool, optional
1853
+ If True, compress the weighting function using 10th and 90th percentiles.
1854
+ Default is False.
1855
+ displayplots : bool, optional
1856
+ If True, display the reciprocal weighting function as a plot.
1857
+ Default is False.
1110
1858
 
1111
1859
  Returns
1112
1860
  -------
1113
- product
1861
+ NDArray
1862
+ The weighted GCC product. The output is of the same shape as the input arrays.
1863
+ If `weighting` is 'None', the raw product is returned.
1864
+ If `threshfrac` is 0, a zero array of the same shape is returned.
1865
+
1866
+ Notes
1867
+ -----
1868
+ The weighting functions are applied element-wise and are designed to suppress
1869
+ noise and enhance correlation peaks. The 'phat' weighting is commonly used in
1870
+ speech and signal processing due to its robustness.
1871
+
1872
+ Examples
1873
+ --------
1874
+ >>> import numpy as np
1875
+ >>> fft1 = np.random.rand(100) + 1j * np.random.rand(100)
1876
+ >>> fft2 = np.random.rand(100) + 1j * np.random.rand(100)
1877
+ >>> result = gccproduct(fft1, fft2, weighting='phat', threshfrac=0.05)
1114
1878
  """
1115
1879
  product = fft1 * fft2
1116
1880
  if weighting == "None":
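Of the weightings listed in the docstring, PHAT is the simplest to sketch outside the package: the cross-spectrum is divided by its own magnitude so that only phase (delay) information remains, which sharpens the correlation peak. A minimal, self-contained version (not the gccproduct implementation):

import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal(1024)
y = np.roll(x, 37)                            # y is x delayed by 37 samples

cross = np.fft.fft(x) * np.conj(np.fft.fft(y))
phat = cross / (np.abs(cross) + 1e-15)        # PHAT weighting: keep phase only

cc = np.fft.ifft(phat).real                   # generalized cross-correlation
lag = int(np.argmax(cc))
print(lag - len(x) if lag > len(x) // 2 else lag)   # -37 with this circular convention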
@@ -1159,17 +1923,75 @@ def gccproduct(fft1, fft2, weighting, threshfrac=0.1, compress=False, displayplo
1159
1923
 
1160
1924
 
1161
1925
  def aligntcwithref(
1162
- fixedtc,
1163
- movingtc,
1164
- Fs,
1165
- lagmin=-30,
1166
- lagmax=30,
1167
- refine=True,
1168
- zerooutbadfit=False,
1169
- widthmax=1000.0,
1170
- display=False,
1171
- verbose=False,
1172
- ):
1926
+ fixedtc: NDArray,
1927
+ movingtc: NDArray,
1928
+ Fs: float,
1929
+ lagmin: float = -30,
1930
+ lagmax: float = 30,
1931
+ refine: bool = True,
1932
+ zerooutbadfit: bool = False,
1933
+ widthmax: float = 1000.0,
1934
+ display: bool = False,
1935
+ verbose: bool = False,
1936
+ ) -> Tuple[NDArray, float, float, int]:
1937
+ """
1938
+ Align a moving timecourse to a fixed reference timecourse using cross-correlation.
1939
+
1940
+ This function computes the cross-correlation between two timecourses and finds the
1941
+ optimal time lag that maximizes their similarity. The moving timecourse is then
1942
+ aligned to the fixed one using this lag.
1943
+
1944
+ Parameters
1945
+ ----------
1946
+ fixedtc : ndarray
1947
+ The reference timecourse to which the moving timecourse will be aligned.
1948
+ movingtc : ndarray
1949
+ The timecourse to be aligned to the fixed timecourse.
1950
+ Fs : float
1951
+ Sampling frequency of the timecourses in Hz.
1952
+ lagmin : float, optional
1953
+ Minimum lag to consider in seconds. Default is -30.
1954
+ lagmax : float, optional
1955
+ Maximum lag to consider in seconds. Default is 30.
1956
+ refine : bool, optional
1957
+ If True, refine the lag estimate using Gaussian fitting. Default is True.
1958
+ zerooutbadfit : bool, optional
1959
+ If True, zero out the cross-correlation values for bad fits. Default is False.
1960
+ widthmax : float, optional
1961
+ Maximum allowed width of the Gaussian fit in samples. Default is 1000.0.
1962
+ display : bool, optional
1963
+ If True, display plots of the cross-correlation and aligned timecourses. Default is False.
1964
+ verbose : bool, optional
1965
+ If True, print detailed information about the cross-correlation results. Default is False.
1966
+
1967
+ Returns
1968
+ -------
1969
+ tuple
1970
+ A tuple containing:
1971
+ - aligneddata : ndarray
1972
+ The moving timecourse aligned to the fixed timecourse.
1973
+ - maxdelay : float
1974
+ The estimated time lag (in seconds) that maximizes cross-correlation.
1975
+ - maxval : float
1976
+ The maximum cross-correlation value.
1977
+ - failreason : int
1978
+ Reason for failure (0 = success, other values indicate specific failure types).
1979
+
1980
+ Notes
1981
+ -----
1982
+ This function uses `fastcorrelate` for efficient cross-correlation computation and
1983
+ `tide_fit.findmaxlag_gauss` to estimate the optimal lag with optional Gaussian refinement.
1984
+ The alignment is performed using `tide_resample.doresample`.
1985
+
1986
+ Examples
1987
+ --------
1988
+ >>> import numpy as np
1989
+ >>> from typing import Tuple
1990
+ >>> fixed = np.random.rand(1000)
1991
+ >>> moving = np.roll(fixed, 10) # shift by 10 samples
1992
+ >>> aligned, delay, corr, fail = aligntcwithref(fixed, moving, Fs=100)
1993
+ >>> print(f"Estimated delay: {delay}s")
1994
+ """
1173
1995
  # now fixedtc and 2 are on the same timescales
1174
1996
  thexcorr = fastcorrelate(tide_math.corrnormalize(fixedtc), tide_math.corrnormalize(movingtc))
1175
1997
  xcorrlen = len(thexcorr)