rapidtide 3.0.10__py3-none-any.whl → 3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. rapidtide/Colortables.py +492 -27
  2. rapidtide/OrthoImageItem.py +1053 -47
  3. rapidtide/RapidtideDataset.py +1533 -86
  4. rapidtide/_version.py +3 -3
  5. rapidtide/calccoherence.py +196 -29
  6. rapidtide/calcnullsimfunc.py +191 -40
  7. rapidtide/calcsimfunc.py +245 -42
  8. rapidtide/correlate.py +1210 -393
  9. rapidtide/data/examples/src/testLD +56 -0
  10. rapidtide/data/examples/src/testalign +1 -1
  11. rapidtide/data/examples/src/testdelayvar +0 -1
  12. rapidtide/data/examples/src/testfmri +19 -1
  13. rapidtide/data/examples/src/testglmfilt +5 -5
  14. rapidtide/data/examples/src/testhappy +30 -1
  15. rapidtide/data/examples/src/testppgproc +17 -0
  16. rapidtide/data/examples/src/testrolloff +11 -0
  17. rapidtide/data/models/model_cnn_pytorch/best_model.pth +0 -0
  18. rapidtide/data/models/model_cnn_pytorch/loss.png +0 -0
  19. rapidtide/data/models/model_cnn_pytorch/loss.txt +1 -0
  20. rapidtide/data/models/model_cnn_pytorch/model.pth +0 -0
  21. rapidtide/data/models/model_cnn_pytorch/model_meta.json +68 -0
  22. rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1_space-MNI152NLin2009cAsym_2mm.nii.gz +0 -0
  23. rapidtide/data/reference/JHU-ArterialTerritoriesNoVent-LVL1_space-MNI152NLin2009cAsym_2mm_mask.nii.gz +0 -0
  24. rapidtide/decorators.py +91 -0
  25. rapidtide/dlfilter.py +2225 -108
  26. rapidtide/dlfiltertorch.py +4843 -0
  27. rapidtide/externaltools.py +327 -12
  28. rapidtide/fMRIData_class.py +79 -40
  29. rapidtide/filter.py +1899 -810
  30. rapidtide/fit.py +2004 -574
  31. rapidtide/genericmultiproc.py +93 -18
  32. rapidtide/happy_supportfuncs.py +2044 -171
  33. rapidtide/helper_classes.py +584 -43
  34. rapidtide/io.py +2363 -370
  35. rapidtide/linfitfiltpass.py +341 -75
  36. rapidtide/makelaggedtcs.py +211 -20
  37. rapidtide/maskutil.py +423 -53
  38. rapidtide/miscmath.py +827 -121
  39. rapidtide/multiproc.py +210 -22
  40. rapidtide/patchmatch.py +234 -33
  41. rapidtide/peakeval.py +32 -30
  42. rapidtide/ppgproc.py +2203 -0
  43. rapidtide/qualitycheck.py +352 -39
  44. rapidtide/refinedelay.py +422 -57
  45. rapidtide/refineregressor.py +498 -184
  46. rapidtide/resample.py +671 -185
  47. rapidtide/scripts/applyppgproc.py +28 -0
  48. rapidtide/simFuncClasses.py +1052 -77
  49. rapidtide/simfuncfit.py +260 -46
  50. rapidtide/stats.py +540 -238
  51. rapidtide/tests/happycomp +9 -0
  52. rapidtide/tests/test_dlfiltertorch.py +627 -0
  53. rapidtide/tests/test_findmaxlag.py +24 -8
  54. rapidtide/tests/test_fullrunhappy_v1.py +0 -2
  55. rapidtide/tests/test_fullrunhappy_v2.py +0 -2
  56. rapidtide/tests/test_fullrunhappy_v3.py +1 -0
  57. rapidtide/tests/test_fullrunhappy_v4.py +2 -2
  58. rapidtide/tests/test_fullrunrapidtide_v7.py +1 -1
  59. rapidtide/tests/test_simroundtrip.py +8 -8
  60. rapidtide/tests/utils.py +9 -8
  61. rapidtide/tidepoolTemplate.py +142 -38
  62. rapidtide/tidepoolTemplate_alt.py +165 -44
  63. rapidtide/tidepoolTemplate_big.py +189 -52
  64. rapidtide/util.py +1217 -118
  65. rapidtide/voxelData.py +684 -37
  66. rapidtide/wiener.py +19 -12
  67. rapidtide/wiener2.py +113 -7
  68. rapidtide/wiener_doc.py +255 -0
  69. rapidtide/workflows/adjustoffset.py +105 -3
  70. rapidtide/workflows/aligntcs.py +85 -2
  71. rapidtide/workflows/applydlfilter.py +87 -10
  72. rapidtide/workflows/applyppgproc.py +522 -0
  73. rapidtide/workflows/atlasaverage.py +210 -47
  74. rapidtide/workflows/atlastool.py +100 -3
  75. rapidtide/workflows/calcSimFuncMap.py +294 -64
  76. rapidtide/workflows/calctexticc.py +201 -9
  77. rapidtide/workflows/ccorrica.py +97 -4
  78. rapidtide/workflows/cleanregressor.py +168 -29
  79. rapidtide/workflows/delayvar.py +163 -10
  80. rapidtide/workflows/diffrois.py +81 -3
  81. rapidtide/workflows/endtidalproc.py +144 -4
  82. rapidtide/workflows/fdica.py +195 -15
  83. rapidtide/workflows/filtnifti.py +70 -3
  84. rapidtide/workflows/filttc.py +74 -3
  85. rapidtide/workflows/fitSimFuncMap.py +206 -48
  86. rapidtide/workflows/fixtr.py +73 -3
  87. rapidtide/workflows/gmscalc.py +113 -3
  88. rapidtide/workflows/happy.py +813 -201
  89. rapidtide/workflows/happy2std.py +144 -12
  90. rapidtide/workflows/happy_parser.py +149 -8
  91. rapidtide/workflows/histnifti.py +118 -2
  92. rapidtide/workflows/histtc.py +84 -3
  93. rapidtide/workflows/linfitfilt.py +117 -4
  94. rapidtide/workflows/localflow.py +328 -28
  95. rapidtide/workflows/mergequality.py +79 -3
  96. rapidtide/workflows/niftidecomp.py +322 -18
  97. rapidtide/workflows/niftistats.py +174 -4
  98. rapidtide/workflows/pairproc.py +88 -2
  99. rapidtide/workflows/pairwisemergenifti.py +85 -2
  100. rapidtide/workflows/parser_funcs.py +1421 -40
  101. rapidtide/workflows/physiofreq.py +137 -11
  102. rapidtide/workflows/pixelcomp.py +208 -5
  103. rapidtide/workflows/plethquality.py +103 -21
  104. rapidtide/workflows/polyfitim.py +151 -11
  105. rapidtide/workflows/proj2flow.py +75 -2
  106. rapidtide/workflows/rankimage.py +111 -4
  107. rapidtide/workflows/rapidtide.py +272 -15
  108. rapidtide/workflows/rapidtide2std.py +98 -2
  109. rapidtide/workflows/rapidtide_parser.py +109 -9
  110. rapidtide/workflows/refineDelayMap.py +143 -33
  111. rapidtide/workflows/refineRegressor.py +682 -93
  112. rapidtide/workflows/regressfrommaps.py +152 -31
  113. rapidtide/workflows/resamplenifti.py +85 -3
  114. rapidtide/workflows/resampletc.py +91 -3
  115. rapidtide/workflows/retrolagtcs.py +98 -6
  116. rapidtide/workflows/retroregress.py +165 -9
  117. rapidtide/workflows/roisummarize.py +173 -5
  118. rapidtide/workflows/runqualitycheck.py +71 -3
  119. rapidtide/workflows/showarbcorr.py +147 -4
  120. rapidtide/workflows/showhist.py +86 -2
  121. rapidtide/workflows/showstxcorr.py +160 -3
  122. rapidtide/workflows/showtc.py +159 -3
  123. rapidtide/workflows/showxcorrx.py +184 -4
  124. rapidtide/workflows/showxy.py +185 -15
  125. rapidtide/workflows/simdata.py +262 -36
  126. rapidtide/workflows/spatialfit.py +77 -2
  127. rapidtide/workflows/spatialmi.py +251 -27
  128. rapidtide/workflows/spectrogram.py +305 -32
  129. rapidtide/workflows/synthASL.py +154 -3
  130. rapidtide/workflows/tcfrom2col.py +76 -2
  131. rapidtide/workflows/tcfrom3col.py +74 -2
  132. rapidtide/workflows/tidepool.py +2972 -133
  133. rapidtide/workflows/utils.py +19 -14
  134. rapidtide/workflows/utils_doc.py +293 -0
  135. rapidtide/workflows/variabilityizer.py +116 -3
  136. {rapidtide-3.0.10.dist-info → rapidtide-3.1.dist-info}/METADATA +10 -9
  137. {rapidtide-3.0.10.dist-info → rapidtide-3.1.dist-info}/RECORD +141 -122
  138. {rapidtide-3.0.10.dist-info → rapidtide-3.1.dist-info}/entry_points.txt +1 -0
  139. {rapidtide-3.0.10.dist-info → rapidtide-3.1.dist-info}/WHEEL +0 -0
  140. {rapidtide-3.0.10.dist-info → rapidtide-3.1.dist-info}/licenses/LICENSE +0 -0
  141. {rapidtide-3.0.10.dist-info → rapidtide-3.1.dist-info}/top_level.txt +0 -0
rapidtide/correlate.py CHANGED
@@ -19,9 +19,11 @@
19
19
  """Functions for calculating correlations and similar metrics between arrays."""
20
20
  import logging
21
21
  import warnings
22
+ from typing import Any, Callable, Optional, Tuple, Union
22
23
 
23
24
  import matplotlib.pyplot as plt
24
25
  import numpy as np
26
+ from numpy.typing import NDArray
25
27
 
26
28
  with warnings.catch_warnings():
27
29
  warnings.simplefilter("ignore")
@@ -44,6 +46,7 @@ import rapidtide.miscmath as tide_math
44
46
  import rapidtide.resample as tide_resample
45
47
  import rapidtide.stats as tide_stats
46
48
  import rapidtide.util as tide_util
49
+ from rapidtide.decorators import conditionaljit
47
50
 
48
51
  if pyfftwpresent:
49
52
  fftpack = pyfftw.interfaces.scipy_fftpack
@@ -56,60 +59,62 @@ MAXLINES = 10000000
56
59
  donotbeaggressive = True
57
60
 
58
61
 
59
- # ----------------------------------------- Conditional imports ---------------------------------------
60
- try:
61
- from numba import jit
62
- except ImportError:
63
- donotusenumba = True
64
- else:
65
- donotusenumba = False
66
-
67
-
68
- def conditionaljit():
69
- """Wrap functions in jit if numba is enabled."""
70
-
71
- def resdec(f):
72
- if donotusenumba:
73
- return f
74
- return jit(f, nopython=True)
75
-
76
- return resdec
77
-
78
-
79
- def disablenumba():
80
- """Set a global variable to disable numba."""
81
- global donotusenumba
82
- donotusenumba = True
83
-
84
-
85
62
  # --------------------------- Correlation functions -------------------------------------------------
86
63
  def check_autocorrelation(
87
- corrscale,
88
- thexcorr,
89
- delta=0.05,
90
- acampthresh=0.1,
91
- aclagthresh=10.0,
92
- displayplots=False,
93
- detrendorder=1,
94
- debug=False,
95
- ):
96
- """Check for autocorrelation in an array.
64
+ corrscale: NDArray,
65
+ thexcorr: NDArray,
66
+ delta: float = 0.05,
67
+ acampthresh: float = 0.1,
68
+ aclagthresh: float = 10.0,
69
+ displayplots: bool = False,
70
+ detrendorder: int = 1,
71
+ debug: bool = False,
72
+ ) -> Tuple[Optional[float], Optional[float]]:
73
+ """
74
+ Check for autocorrelation peaks in a cross-correlation signal and fit a Gaussian to the sidelobe.
75
+
76
+ This function identifies peaks in the cross-correlation signal and, if a significant
77
+ sidelobe is detected (based on amplitude and lag thresholds), fits a Gaussian function
78
+ to estimate the sidelobe's time and amplitude.
97
79
 
98
80
  Parameters
99
81
  ----------
100
- corrscale
101
- thexcorr
102
- delta
103
- acampthresh
104
- aclagthresh
105
- displayplots
106
- windowfunc
107
- detrendorder
82
+ corrscale : NDArray
83
+ Array of time lags corresponding to the cross-correlation values.
84
+ thexcorr : NDArray
85
+ Array of cross-correlation values.
86
+ delta : float, optional
87
+ Minimum distance between peaks, default is 0.05.
88
+ acampthresh : float, optional
89
+ Amplitude threshold for detecting sidelobes, default is 0.1.
90
+ aclagthresh : float, optional
91
+ Lag threshold beyond which sidelobes are ignored, default is 10.0.
92
+ displayplots : bool, optional
93
+ If True, display the cross-correlation plot with detected peaks, default is False.
94
+ detrendorder : int, optional
95
+ Order of detrending to apply to the signal, default is 1.
96
+ debug : bool, optional
97
+ If True, print debug information, default is False.
108
98
 
109
99
  Returns
110
100
  -------
111
- sidelobetime
112
- sidelobeamp
101
+ Tuple[Optional[float], Optional[float]]
102
+ A tuple containing the estimated sidelobe time and amplitude if a valid sidelobe is found,
103
+ otherwise (None, None).
104
+
105
+ Notes
106
+ -----
107
+ - The function uses `peakdetect` to find peaks in the cross-correlation.
108
+ - A Gaussian fit is performed only if a peak is found beyond the zero-lag point and
109
+ satisfies the amplitude and lag thresholds.
110
+ - The fit is performed on a window around the detected sidelobe.
111
+
112
+ Examples
113
+ --------
114
+ >>> corrscale = np.linspace(0, 20, 100)
115
+ >>> thexcorr = np.exp(-0.5 * (corrscale - 5)**2 / 2) + 0.1 * np.random.rand(100)
116
+ >>> time, amp = check_autocorrelation(corrscale, thexcorr, delta=0.1, acampthresh=0.05)
117
+ >>> print(f"Sidelobe time: {time}, Amplitude: {amp}")
113
118
  """
114
119
  if debug:
115
120
  print("check_autocorrelation:")
@@ -176,31 +181,61 @@ def check_autocorrelation(
176
181
 
177
182
 
178
183
  def shorttermcorr_1D(
179
- data1,
180
- data2,
181
- sampletime,
182
- windowtime,
183
- samplestep=1,
184
- detrendorder=0,
185
- windowfunc="hamming",
186
- ):
187
- """Calculate short-term sliding-window correlation between two 1D arrays.
184
+ data1: NDArray,
185
+ data2: NDArray,
186
+ sampletime: float,
187
+ windowtime: float,
188
+ samplestep: int = 1,
189
+ detrendorder: int = 0,
190
+ windowfunc: str = "hamming",
191
+ ) -> Tuple[NDArray, NDArray, NDArray]:
192
+ """
193
+ Compute short-term cross-correlation between two 1D signals using sliding windows.
194
+
195
+ This function calculates the Pearson correlation coefficient between two signals
196
+ over short time windows, allowing for the analysis of time-varying correlations.
197
+ The correlation is computed for overlapping windows across the input data,
198
+ with optional detrending and windowing applied to each segment.
188
199
 
189
200
  Parameters
190
201
  ----------
191
- data1
192
- data2
193
- sampletime
194
- windowtime
195
- samplestep
196
- detrendorder
197
- windowfunc
202
+ data1 : NDArray
203
+ First input signal (1D array).
204
+ data2 : NDArray
205
+ Second input signal (1D array). Must have the same length as `data1`.
206
+ sampletime : float
207
+ Time interval between consecutive samples in seconds.
208
+ windowtime : float
209
+ Length of the sliding window in seconds.
210
+ samplestep : int, optional
211
+ Step size (in samples) between consecutive windows. Default is 1.
212
+ detrendorder : int, optional
213
+ Order of detrending to apply before correlation. 0 means no detrending.
214
+ Default is 0.
215
+ windowfunc : str, optional
216
+ Window function to apply to each segment. Default is "hamming".
198
217
 
199
218
  Returns
200
219
  -------
201
- times
202
- corrpertime
203
- ppertime
220
+ times : NDArray
221
+ Array of time values corresponding to the center of each window.
222
+ corrpertime : NDArray
223
+ Array of Pearson correlation coefficients for each window.
224
+ ppertime : NDArray
225
+ Array of p-values associated with the correlation coefficients.
226
+
227
+ Notes
228
+ -----
229
+ The function uses `tide_math.corrnormalize` for normalization and detrending
230
+ of signal segments, and `scipy.stats.pearsonr` for computing the correlation.
231
+
232
+ Examples
233
+ --------
234
+ >>> import numpy as np
235
+ >>> data1 = np.random.randn(1000)
236
+ >>> data2 = np.random.randn(1000)
237
+ >>> times, corr, pvals = shorttermcorr_1D(data1, data2, 0.1, 1.0)
238
+ >>> print(f"Correlation at time {times[0]:.2f}: {corr[0]:.3f}")
204
239
  """
205
240
  windowsize = int(windowtime // sampletime)
206
241
  halfwindow = int((windowsize + 1) // 2)
@@ -230,42 +265,85 @@ def shorttermcorr_1D(
230
265
 
231
266
 
232
267
  def shorttermcorr_2D(
233
- data1,
234
- data2,
235
- sampletime,
236
- windowtime,
237
- samplestep=1,
238
- laglimits=None,
239
- weighting="None",
240
- zeropadding=0,
241
- windowfunc="None",
242
- detrendorder=0,
243
- compress=False,
244
- displayplots=False,
245
- ):
246
- """Calculate short-term sliding-window correlation between two 2D arrays.
268
+ data1: NDArray,
269
+ data2: NDArray,
270
+ sampletime: float,
271
+ windowtime: float,
272
+ samplestep: int = 1,
273
+ laglimits: Optional[Tuple[float, float]] = None,
274
+ weighting: str = "None",
275
+ zeropadding: int = 0,
276
+ windowfunc: str = "None",
277
+ detrendorder: int = 0,
278
+ compress: bool = False,
279
+ displayplots: bool = False,
280
+ ) -> Tuple[NDArray, NDArray, NDArray, NDArray, NDArray]:
281
+ """
282
+ Compute short-term cross-correlations between two 1D signals over sliding windows.
283
+
284
+ This function computes the cross-correlation between two input signals (`data1` and `data2`)
285
+ using a sliding window approach. For each window, the cross-correlation is computed and
286
+ the peak lag and correlation coefficient are extracted. The function supports detrending,
287
+ windowing, and various correlation weighting schemes.
247
288
 
248
289
  Parameters
249
290
  ----------
250
- data1
251
- data2
252
- sampletime
253
- windowtime
254
- samplestep
255
- laglimits
256
- weighting
257
- zeropadding
258
- windowfunc
259
- detrendorder
260
- displayplots
291
+ data1 : NDArray
292
+ First input signal (1D array).
293
+ data2 : NDArray
294
+ Second input signal (1D array). Must be of the same length as `data1`.
295
+ sampletime : float
296
+ Sampling interval of the input signals in seconds.
297
+ windowtime : float
298
+ Length of the sliding window in seconds.
299
+ samplestep : int, optional
300
+ Step size (in samples) for the sliding window. Default is 1.
301
+ laglimits : Tuple[float, float], optional
302
+ Minimum and maximum lag limits (in seconds) for peak detection.
303
+ If None, defaults to ±windowtime/2.
304
+ weighting : str, optional
305
+ Type of weighting to apply during cross-correlation ('None', 'hamming', etc.).
306
+ Default is 'None'.
307
+ zeropadding : int, optional
308
+ Zero-padding factor for the FFT-based correlation. Default is 0.
309
+ windowfunc : str, optional
310
+ Type of window function to apply ('None', 'hamming', etc.). Default is 'None'.
311
+ detrendorder : int, optional
312
+ Order of detrending to apply before correlation (0 = no detrend, 1 = linear, etc.).
313
+ Default is 0.
314
+ compress : bool, optional
315
+ Whether to compress the correlation result. Default is False.
316
+ displayplots : bool, optional
317
+ Whether to display intermediate plots (e.g., correlation matrix). Default is False.
261
318
 
262
319
  Returns
263
320
  -------
264
- times
265
- xcorrpertime
266
- Rvals
267
- delayvals
268
- valid
321
+ times : NDArray
322
+ Array of time values corresponding to the center of each window.
323
+ xcorrpertime : NDArray
324
+ Array of cross-correlation functions for each window.
325
+ Rvals : NDArray
326
+ Correlation coefficients for each window.
327
+ delayvals : NDArray
328
+ Estimated time delays (lags) for each window.
329
+ valid : NDArray
330
+ Binary array indicating whether the peak detection was successful (1) or failed (0).
331
+
332
+ Notes
333
+ -----
334
+ - The function uses `fastcorrelate` for efficient cross-correlation computation.
335
+ - Peak detection is performed using `tide_fit.findmaxlag_gauss`.
336
+ - If `displayplots` is True, an image of the cross-correlations is shown.
337
+
338
+ Examples
339
+ --------
340
+ >>> import numpy as np
341
+ >>> t = np.linspace(0, 10, 1000)
342
+ >>> signal1 = np.sin(2 * np.pi * 0.5 * t)
343
+ >>> signal2 = np.sin(2 * np.pi * 0.5 * t + 0.1)
344
+ >>> times, xcorrs, Rvals, delays, valid = shorttermcorr_2D(
345
+ ... signal1, signal2, sampletime=0.01, windowtime=1.0
346
+ ... )
269
347
  """
270
348
  windowsize = int(windowtime // sampletime)
271
349
  halfwindow = int((windowsize + 1) // 2)
@@ -356,74 +434,230 @@ def shorttermcorr_2D(
356
434
  )
357
435
 
358
436
 
359
- def calc_MI(x, y, bins=50):
360
- """Calculate mutual information between two arrays.
437
+ def calc_MI(x: NDArray, y: NDArray, bins: int = 50) -> float:
438
+ """
439
+ Calculate mutual information between two arrays.
440
+
441
+ Parameters
442
+ ----------
443
+ x : array-like
444
+ First array of data points
445
+ y : array-like
446
+ Second array of data points
447
+ bins : int, optional
448
+ Number of bins to use for histogram estimation, default is 50
449
+
450
+ Returns
451
+ -------
452
+ float
453
+ Mutual information between x and y
361
454
 
362
455
  Notes
363
456
  -----
364
- From https://stackoverflow.com/questions/20491028/
365
- optimal-way-to-compute-pairwise-mutual-information-using-numpy/
457
+ This implementation uses 2D histogram estimation followed by mutual information
458
+ calculation. The method is based on the approach from:
459
+ https://stackoverflow.com/questions/20491028/optimal-way-to-compute-pairwise-mutual-information-using-numpy/
366
460
  20505476#20505476
461
+
462
+ Examples
463
+ --------
464
+ >>> import numpy as np
465
+ >>> x = np.random.randn(1000)
466
+ >>> y = x + np.random.randn(1000) * 0.5
467
+ >>> mi = calc_MI(x, y)
468
+ >>> print(f"Mutual information: {mi:.3f}")
367
469
  """
368
470
  c_xy = np.histogram2d(x, y, bins)[0]
369
471
  mi = mutual_info_score(None, None, contingency=c_xy)
370
472
  return mi
371
473
 
372
474
 
475
+ # @conditionaljit()
476
+ def mutual_info_2d_fast(
477
+ x: NDArray[np.floating[Any]],
478
+ y: NDArray[np.floating[Any]],
479
+ bins: Tuple[NDArray, NDArray],
480
+ sigma: float = 1,
481
+ normalized: bool = True,
482
+ EPS: float = 1.0e-6,
483
+ debug: bool = False,
484
+ ) -> float:
485
+ """
486
+ Compute (normalized) mutual information between two 1D variates from a joint histogram.
487
+
488
+ Parameters
489
+ ----------
490
+ x : 1D NDArray[np.floating[Any]]
491
+ First variable.
492
+ y : 1D NDArray[np.floating[Any]]
493
+ Second variable.
494
+ bins : tuple of NDArray
495
+ Bin edges for the histogram. The first element corresponds to `x` and the second to `y`.
496
+ sigma : float, optional
497
+ Sigma for Gaussian smoothing of the joint histogram. Default is 1.
498
+ normalized : bool, optional
499
+ If True, compute normalized mutual information as defined in [1]_. Default is True.
500
+ EPS : float, optional
501
+ Small constant to avoid numerical errors in logarithms. Default is 1e-6.
502
+ debug : bool, optional
503
+ If True, print intermediate values for debugging. Default is False.
504
+
505
+ Returns
506
+ -------
507
+ float
508
+ The computed mutual information (or normalized mutual information if `normalized=True`).
509
+
510
+ Notes
511
+ -----
512
+ This function computes mutual information using a 2D histogram and Gaussian smoothing.
513
+ The normalization follows the approach described in [1]_.
514
+
515
+ References
516
+ ----------
517
+ .. [1] Colin Studholme, David John Hawkes, Derek L.G. Hill (1998).
518
+ "Normalized entropy measure for multimodality image alignment".
519
+ in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
520
+
521
+ Examples
522
+ --------
523
+ >>> import numpy as np
524
+ >>> x = np.random.randn(1000)
525
+ >>> y = np.random.randn(1000)
526
+ >>> bins = (np.linspace(-3, 3, 64), np.linspace(-3, 3, 64))
527
+ >>> mi = mutual_info_2d_fast(x, y, bins)
528
+ >>> print(mi)
529
+ """
530
+ xstart = bins[0][0]
531
+ xend = bins[0][-1]
532
+ ystart = bins[1][0]
533
+ yend = bins[1][-1]
534
+ numxbins = int(len(bins[0]) - 1)
535
+ numybins = int(len(bins[1]) - 1)
536
+ cuts = (x >= xstart) & (x < xend) & (y >= ystart) & (y < yend)
537
+ c = ((x[cuts] - xstart) / (xend - xstart) * numxbins).astype(np.int_)
538
+ c += ((y[cuts] - ystart) / (yend - ystart) * numybins).astype(np.int_) * numxbins
539
+ jh = np.bincount(c, minlength=numxbins * numybins).reshape(numxbins, numybins)
540
+
541
+ return proc_MI_histogram(jh, sigma=sigma, normalized=normalized, EPS=EPS, debug=debug)
542
+
543
+
373
544
  # @conditionaljit()
374
545
  def mutual_info_2d(
375
- x, y, sigma=1, bins=(256, 256), fast=False, normalized=True, EPS=1.0e-6, debug=False
376
- ):
377
- """Compute (normalized) mutual information between two 1D variate from a joint histogram.
546
+ x: NDArray[np.floating[Any]],
547
+ y: NDArray[np.floating[Any]],
548
+ bins: Tuple[int, int],
549
+ sigma: float = 1,
550
+ normalized: bool = True,
551
+ EPS: float = 1.0e-6,
552
+ debug: bool = False,
553
+ ) -> float:
554
+ """
555
+ Compute (normalized) mutual information between two 1D variates from a joint histogram.
378
556
 
379
- Parameters
380
- ----------
381
- x : 1D array
382
- first variable
383
- y : 1D array
384
- second variable
385
- sigma : float, optional
386
- Sigma for Gaussian smoothing of the joint histogram.
387
- Default = 1.
388
- bins : tuple, optional
389
- fast : bool, optional
390
- normalized : bool
391
- If True, this will calculate the normalized mutual information from [1]_.
392
- Default = False.
393
- EPS : float, optional
394
- Default = 1.0e-6.
557
+ Parameters
558
+ ----------
559
+ x : 1D NDArray[np.floating[Any]]
560
+ First variable.
561
+ y : 1D NDArray[np.floating[Any]]
562
+ Second variable.
563
+ bins : tuple of int
564
+ Number of bins for the histogram. The first element is the number of bins for `x`
565
+ and the second for `y`.
566
+ sigma : float, optional
567
+ Sigma for Gaussian smoothing of the joint histogram. Default is 1.
568
+ normalized : bool, optional
569
+ If True, compute normalized mutual information as defined in [1]_. Default is True.
570
+ EPS : float, optional
571
+ Small constant to avoid numerical errors in logarithms. Default is 1e-6.
572
+ debug : bool, optional
573
+ If True, print intermediate values for debugging. Default is False.
395
574
 
396
- Returns
397
- -------
398
- nmi: float
399
- the computed similarity measure
575
+ Returns
576
+ -------
577
+ float
578
+ The computed mutual information (or normalized mutual information if `normalized=True`).
400
579
 
401
- Notes
402
- -----
403
- From Ionnis Pappas
404
- BBF added the precaching (fast) option
580
+ Notes
581
+ -----
582
+ This function computes mutual information using a 2D histogram and Gaussian smoothing.
583
+ The normalization follows the approach described in [1]_.
405
584
 
406
- References
407
- ----------
408
- .. [1] Colin Studholme, David John Hawkes, Derek L.G. Hill (1998).
409
- "Normalized entropy measure for multimodality image alignment".
410
- in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
585
+ References
586
+ ----------
587
+ .. [1] Colin Studholme, David John Hawkes, Derek L.G. Hill (1998).
588
+ "Normalized entropy measure for multimodality image alignment".
589
+ in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
590
+
591
+ Examples
592
+ --------
593
+ >>> import numpy as np
594
+ >>> x = np.random.randn(1000)
595
+ >>> y = np.random.randn(1000)
596
+ >>> mi = mutual_info_2d(x, y, (256, 256))
597
+ >>> print(mi)
598
+ """
599
+ jh, xbins, ybins = np.histogram2d(x, y, bins=bins)
600
+ if debug:
601
+ print(f"{xbins} {ybins}")
602
+
603
+ return proc_MI_histogram(jh, sigma=sigma, normalized=normalized, EPS=EPS, debug=debug)
604
+
605
+
606
+ def proc_MI_histogram(
607
+ jh: NDArray[np.floating[Any]],
608
+ sigma: float = 1,
609
+ normalized: bool = True,
610
+ EPS: float = 1.0e-6,
611
+ debug: bool = False,
612
+ ) -> float:
411
613
  """
412
- if fast:
413
- xstart = bins[0][0]
414
- xend = bins[0][-1]
415
- ystart = bins[1][0]
416
- yend = bins[1][-1]
417
- numxbins = len(bins[0]) - 1
418
- numybins = len(bins[1]) - 1
419
- cuts = (x >= xstart) & (x < xend) & (y >= ystart) & (y < yend)
420
- c = ((x[cuts] - xstart) / (xend - xstart) * numxbins).astype(np.int_)
421
- c += ((y[cuts] - ystart) / (yend - ystart) * numybins).astype(np.int_) * numxbins
422
- jh = np.bincount(c, minlength=numxbins * numybins).reshape(numxbins, numybins)
423
- else:
424
- jh, xbins, ybins = np.histogram2d(x, y, bins=bins)
425
- if debug:
426
- print(f"{xbins} {ybins}")
614
+ Compute the mutual information (MI) between two variables from a joint histogram.
615
+
616
+ This function calculates mutual information using the joint histogram of two variables,
617
+ applying Gaussian smoothing and computing entropy-based MI. It supports both normalized
618
+ and unnormalized versions of the mutual information.
619
+
620
+ Parameters
621
+ ----------
622
+ jh : ndarray of shape (m, n)
623
+ Joint histogram of two variables. Should be a 2D array of floating point values.
624
+ sigma : float, optional
625
+ Standard deviation for Gaussian smoothing of the joint histogram. Default is 1.0.
626
+ normalized : bool, optional
627
+ If True, returns normalized mutual information. If False, returns unnormalized
628
+ mutual information. Default is True.
629
+ EPS : float, optional
630
+ Small constant added to the histogram to avoid numerical issues in log computation.
631
+ Default is 1e-6.
632
+ debug : bool, optional
633
+ If True, prints intermediate values for debugging purposes. Default is False.
634
+
635
+ Returns
636
+ -------
637
+ float
638
+ The computed mutual information (MI) between the two variables. The value is
639
+ positive and indicates the amount of information shared between the variables.
640
+
641
+ Notes
642
+ -----
643
+ The function applies Gaussian smoothing to the joint histogram before computing
644
+ marginal and joint entropies. The mutual information is computed as:
645
+
646
+ .. math::
647
+ MI = \\frac{H(X) + H(Y)}{H(X,Y)} - 1
648
+
649
+ where :math:`H(X)`, :math:`H(Y)`, and :math:`H(X,Y)` are the marginal and joint entropies,
650
+ respectively. If `normalized=False`, the unnormalized MI is returned instead.
651
+
652
+ Examples
653
+ --------
654
+ >>> import numpy as np
655
+ >>> from scipy import ndimage
656
+ >>> jh = np.random.rand(10, 10)
657
+ >>> mi = proc_MI_histogram(jh, sigma=0.5, normalized=True)
658
+ >>> print(mi)
659
+ 0.123456789
660
+ """
427
661
 
428
662
  # smooth the jh with a gaussian filter of given sigma
429
663
  sp.ndimage.gaussian_filter(jh, sigma=sigma, mode="constant", output=jh)
@@ -458,65 +692,96 @@ def mutual_info_2d(
458
692
 
459
693
  # @conditionaljit
460
694
  def cross_mutual_info(
461
- x,
462
- y,
463
- returnaxis=False,
464
- negsteps=-1,
465
- possteps=-1,
466
- locs=None,
467
- Fs=1.0,
468
- norm=True,
469
- madnorm=False,
470
- windowfunc="None",
471
- bins=-1,
472
- prebin=True,
473
- sigma=0.25,
474
- fast=True,
475
- ):
476
- """Calculate cross-mutual information between two 1D arrays.
695
+ x: NDArray[np.floating[Any]],
696
+ y: NDArray[np.floating[Any]],
697
+ returnaxis: bool = False,
698
+ negsteps: int = -1,
699
+ possteps: int = -1,
700
+ locs: Optional[NDArray] = None,
701
+ Fs: float = 1.0,
702
+ norm: bool = True,
703
+ madnorm: bool = False,
704
+ windowfunc: str = "None",
705
+ bins: int = -1,
706
+ prebin: bool = True,
707
+ sigma: float = 0.25,
708
+ fast: bool = True,
709
+ ) -> Union[NDArray, Tuple[NDArray, NDArray, int]]:
710
+ """
711
+ Calculate cross-mutual information between two 1D arrays.
712
+
713
+ This function computes the cross-mutual information (MI) between two signals
714
+ `x` and `y` at various time lags or specified offsets. It supports normalization,
715
+ windowing, and histogram smoothing for robust estimation.
716
+
477
717
  Parameters
478
718
  ----------
479
- x : 1D array
480
- first variable
481
- y : 1D array
482
- second variable. The length of y must by >= the length of x
483
- returnaxis : bool
484
- set to True to return the time axis
485
- negsteps: int
486
- possteps: int
487
- locs : list
488
- a set of offsets at which to calculate the cross mutual information
489
- Fs=1.0,
490
- norm : bool
491
- calculate normalized MI at each offset
492
- madnorm : bool
493
- set to True to normalize cross MI waveform by it's median average deviate
494
- windowfunc : str
495
- name of the window function to apply to input vectors prior to MI calculation
496
- bins : int
497
- number of bins in each dimension of the 2D histogram. Set to -1 to set automatically
498
- prebin : bool
499
- set to true to cache 2D histogram for all offsets
500
- sigma : float
501
- histogram smoothing kernel
502
- fast: bool
503
- apply speed optimizations
719
+ x : NDArray[np.floating[Any]]
720
+ First variable (signal).
721
+ y : NDArray[np.floating[Any]]
722
+ Second variable (signal). Must have length >= length of `x`.
723
+ returnaxis : bool, optional
724
+ If True, return the time axis along with the MI values. Default is False.
725
+ negsteps : int, optional
726
+ Number of negative time steps to compute MI for. If -1, uses default based on signal length.
727
+ Default is -1.
728
+ possteps : int, optional
729
+ Number of positive time steps to compute MI for. If -1, uses default based on signal length.
730
+ Default is -1.
731
+ locs : ndarray of int, optional
732
+ Specific time offsets at which to compute MI. If None, uses `negsteps` and `possteps`.
733
+ Default is None.
734
+ Fs : float, optional
735
+ Sampling frequency. Used when `returnaxis` is True. Default is 1.0.
736
+ norm : bool, optional
737
+ If True, normalize the MI values. Default is True.
738
+ madnorm : bool, optional
739
+ If True, normalize the MI waveform by its median absolute deviation (MAD).
740
+ Default is False.
741
+ windowfunc : str, optional
742
+ Name of the window function to apply to input signals before MI calculation.
743
+ Default is "None".
744
+ bins : int, optional
745
+ Number of bins for the 2D histogram. If -1, automatically determined.
746
+ Default is -1.
747
+ prebin : bool, optional
748
+ If True, precompute and cache the 2D histogram for all offsets.
749
+ Default is True.
750
+ sigma : float, optional
751
+ Standard deviation of the Gaussian smoothing kernel applied to the histogram.
752
+ Default is 0.25.
753
+ fast : bool, optional
754
+ If True, apply speed optimizations. Default is True.
504
755
 
505
756
  Returns
506
757
  -------
507
- if returnaxis is True:
508
- thexmi_x : 1D array
509
- the set of offsets at which cross mutual information is calculated
510
- thexmi_y : 1D array
511
- the set of cross mutual information values
512
- len(thexmi_x): int
513
- the number of cross mutual information values returned
514
- else:
515
- thexmi_y : 1D array
516
- the set of cross mutual information values
758
+ ndarray or tuple of ndarray
759
+ If `returnaxis` is False:
760
+ The set of cross-mutual information values.
761
+ If `returnaxis` is True:
762
+ Tuple of (time_axis, mi_values, num_values), where:
763
+ - time_axis : ndarray of float
764
+ Time axis corresponding to the MI values.
765
+ - mi_values : ndarray of float
766
+ Cross-mutual information values.
767
+ - num_values : int
768
+ Number of MI values returned.
517
769
 
770
+ Notes
771
+ -----
772
+ - The function normalizes input signals using detrending and optional windowing.
773
+ - Cross-mutual information is computed using 2D histogram estimation and
774
+ mutual information calculation.
775
+ - If `prebin` is True, the 2D histogram is precomputed for efficiency.
776
+
777
+ Examples
778
+ --------
779
+ >>> import numpy as np
780
+ >>> x = np.random.randn(100)
781
+ >>> y = np.random.randn(100)
782
+ >>> mi = cross_mutual_info(x, y)
783
+ >>> mi_axis, mi_vals, num = cross_mutual_info(x, y, returnaxis=True, Fs=10)
518
784
  """
519
-
520
785
  normx = tide_math.corrnormalize(x, detrendorder=1, windowfunc=windowfunc)
521
786
  normy = tide_math.corrnormalize(y, detrendorder=1, windowfunc=windowfunc)
522
787
 
@@ -555,32 +820,56 @@ def cross_mutual_info(
555
820
  else:
556
821
  destloc += 1
557
822
  if i < 0:
558
- thexmi_y[destloc] = mutual_info_2d(
559
- normx[: i + len(normy)],
560
- normy[-i:],
561
- bins=bins2d,
562
- normalized=norm,
563
- fast=fast,
564
- sigma=sigma,
565
- )
823
+ if fast:
824
+ thexmi_y[destloc] = mutual_info_2d_fast(
825
+ normx[: i + len(normy)],
826
+ normy[-i:],
827
+ bins2d,
828
+ normalized=norm,
829
+ sigma=sigma,
830
+ )
831
+ else:
832
+ thexmi_y[destloc] = mutual_info_2d(
833
+ normx[: i + len(normy)],
834
+ normy[-i:],
835
+ bins2d,
836
+ normalized=norm,
837
+ sigma=sigma,
838
+ )
566
839
  elif i == 0:
567
- thexmi_y[destloc] = mutual_info_2d(
568
- normx,
569
- normy,
570
- bins=bins2d,
571
- normalized=norm,
572
- fast=fast,
573
- sigma=sigma,
574
- )
840
+ if fast:
841
+ thexmi_y[destloc] = mutual_info_2d_fast(
842
+ normx,
843
+ normy,
844
+ bins2d,
845
+ normalized=norm,
846
+ sigma=sigma,
847
+ )
848
+ else:
849
+ thexmi_y[destloc] = mutual_info_2d(
850
+ normx,
851
+ normy,
852
+ bins2d,
853
+ normalized=norm,
854
+ sigma=sigma,
855
+ )
575
856
  else:
576
- thexmi_y[destloc] = mutual_info_2d(
577
- normx[i:],
578
- normy[: len(normy) - i],
579
- bins=bins2d,
580
- normalized=norm,
581
- fast=fast,
582
- sigma=sigma,
583
- )
857
+ if fast:
858
+ thexmi_y[destloc] = mutual_info_2d_fast(
859
+ normx[i:],
860
+ normy[: len(normy) - i],
861
+ bins2d,
862
+ normalized=norm,
863
+ sigma=sigma,
864
+ )
865
+ else:
866
+ thexmi_y[destloc] = mutual_info_2d(
867
+ normx[i:],
868
+ normy[: len(normy) - i],
869
+ bins2d,
870
+ normalized=norm,
871
+ sigma=sigma,
872
+ )
584
873
 
585
874
  if madnorm:
586
875
  thexmi_y = tide_math.madnormalize(thexmi_y)
@@ -599,77 +888,92 @@ def cross_mutual_info(
599
888
  return thexmi_y
600
889
 
601
890
 
602
- def mutual_info_to_r(themi, d=1):
603
- """Convert mutual information to Pearson product-moment correlation."""
604
- return np.power(1.0 - np.exp(-2.0 * themi / d), -0.5)
605
-
606
-
607
- def dtw_distance(s1, s2):
608
- # Dynamic time warping function written by GPT-4
609
- # Get the lengths of the two input sequences
610
- n, m = len(s1), len(s2)
891
+ def mutual_info_to_r(themi: float, d: int = 1) -> float:
892
+ """
893
+ Convert mutual information to Pearson product-moment correlation.
611
894
 
612
- # Initialize a (n+1) x (m+1) matrix with zeros
613
- DTW = np.zeros((n + 1, m + 1))
895
+ This function transforms mutual information values into Pearson correlation coefficients
896
+ using the relationship derived from the assumption of joint Gaussian distributions.
614
897
 
615
- # Set the first row and first column of the matrix to infinity, since
616
- # the first element of each sequence cannot be aligned with an empty sequence
617
- DTW[1:, 0] = np.inf
618
- DTW[0, 1:] = np.inf
898
+ Parameters
899
+ ----------
900
+ themi : float
901
+ Mutual information value (in nats) to be converted.
902
+ d : int, default=1
903
+ Dimensionality of the random variables. For single-dimensional variables, d=1.
904
+ For multi-dimensional variables, d represents the number of dimensions.
619
905
 
620
- # Compute the DTW distance by iteratively filling in the matrix
621
- for i in range(1, n + 1):
622
- for j in range(1, m + 1):
623
- # Compute the cost of aligning the i-th element of s1 with the j-th element of s2
624
- cost = abs(s1[i - 1] - s2[j - 1])
906
+ Returns
907
+ -------
908
+ float
909
+ Pearson product-moment correlation coefficient corresponding to the input
910
+ mutual information value. The result is in the range [0, 1].
625
911
 
626
- # Compute the minimum cost of aligning the first i-1 elements of s1 with the first j elements of s2,
627
- # the first i elements of s1 with the first j-1 elements of s2, and the first i-1 elements of s1
628
- # with the first j-1 elements of s2, and add this to the cost of aligning the i-th element of s1
629
- # with the j-th element of s2
630
- DTW[i, j] = cost + np.min([DTW[i - 1, j], DTW[i, j - 1], DTW[i - 1, j - 1]])
912
+ Notes
913
+ -----
914
+ The transformation is based on the formula:
915
+ r = (1 - exp(-2*MI/d))^(-1/2)
631
916
 
632
- # Return the DTW distance between the two sequences, which is the value in the last cell of the matrix
633
- return DTW[n, m]
917
+ This approximation is valid under the assumption that the variables follow
918
+ a joint Gaussian distribution. For non-Gaussian distributions, the relationship
919
+ may not hold exactly.
634
920
 
921
+ Examples
922
+ --------
923
+ >>> mutual_info_to_r(1.0)
924
+ 0.8416445342422313
635
925
 
636
- def delayedcorr(data1, data2, delayval, timestep):
637
- """Calculate correlation between two 1D arrays, at specific delay.
926
+ >>> mutual_info_to_r(2.0, d=2)
927
+ 0.9640275800758169
928
+ """
929
+ return np.power(1.0 - np.exp(-2.0 * themi / d), -0.5)
638
930
 
639
- Parameters
640
- ----------
641
- data1
642
- data2
643
- delayval
644
- timestep
645
931
 
646
- Returns
647
- -------
648
- corr
649
- """
932
+ def delayedcorr(
933
+ data1: NDArray, data2: NDArray, delayval: float, timestep: float
934
+ ) -> Tuple[float, float]:
650
935
  return sp.stats.pearsonr(data1, tide_resample.timeshift(data2, delayval / timestep, 30)[0])
651
936
 
652
937
 
653
- def cepstraldelay(data1, data2, timestep, displayplots=True):
938
+ def cepstraldelay(
939
+ data1: NDArray, data2: NDArray, timestep: float, displayplots: bool = True
940
+ ) -> float:
654
941
  """
655
- Estimate delay between two signals using Choudhary's cepstral analysis method.
942
+ Calculate correlation between two datasets with a time delay applied to the second dataset.
943
+
944
+ This function computes the Pearson correlation coefficient between two datasets,
945
+ where the second dataset is time-shifted by a specified delay before correlation
946
+ is calculated. The time shift is applied using the tide_resample.timeshift function.
656
947
 
657
948
  Parameters
658
949
  ----------
659
- data1
660
- data2
661
- timestep
662
- displayplots
950
+ data1 : NDArray
951
+ First dataset for correlation calculation.
952
+ data2 : NDArray
953
+ Second dataset to be time-shifted and correlated with data1.
954
+ delayval : float
955
+ Time delay to apply to data2, specified in the same units as timestep.
956
+ timestep : float
957
+ Time step of the datasets, used to convert delayval to sample units.
663
958
 
664
959
  Returns
665
960
  -------
666
- arr
961
+ Tuple[float, float]
962
+ Pearson correlation coefficient and p-value from the correlation test.
667
963
 
668
- References
669
- ----------
670
- * Choudhary, H., Bahl, R. & Kumar, A.
671
- Inter-sensor Time Delay Estimation using cepstrum of sum and difference signals in
672
- underwater multipath environment. in 1-7 (IEEE, 2015). doi:10.1109/UT.2015.7108308
964
+ Notes
965
+ -----
966
+ The delayval is converted to sample units by dividing by timestep before
967
+ applying the time shift. The tide_resample.timeshift function is used internally
968
+ with a window parameter of 30.
969
+
970
+ Examples
971
+ --------
972
+ >>> import numpy as np
973
+ >>> data1 = np.array([1, 2, 3, 4, 5])
974
+ >>> data2 = np.array([2, 3, 4, 5, 6])
975
+ >>> corr, p_value = delayedcorr(data1, data2, delay=1.0, timestep=0.1)
976
+ >>> print(f"Correlation: {corr:.3f}")
673
977
  """
674
978
  ceps1, _ = tide_math.complex_cepstrum(data1)
675
979
  ceps2, _ = tide_math.complex_cepstrum(data2)
@@ -707,24 +1011,74 @@ def cepstraldelay(data1, data2, timestep, displayplots=True):
707
1011
 
708
1012
 
709
1013
  class AliasedCorrelator:
710
- """An aliased correlator.
1014
+ def __init__(self, hiressignal, hires_Fs, numsteps):
1015
+ """
1016
+ Initialize the object with high-resolution signal parameters.
711
1017
 
712
- Parameters
713
- ----------
714
- hiressignal : 1D array
715
- The unaliased waveform to match
716
- hires_Fs : float
717
- The sample rate of the unaliased waveform
718
- numsteps : int
719
- Number of distinct slice acquisition times within the TR.
720
- """
1018
+ Parameters
1019
+ ----------
1020
+ hiressignal : array-like
1021
+ High-resolution signal data to be processed.
1022
+ hires_Fs : float
1023
+ Sampling frequency of the high-resolution signal in Hz.
1024
+ numsteps : int
1025
+ Number of steps for signal processing.
721
1026
 
722
- def __init__(self, hiressignal, hires_Fs, numsteps):
1027
+ Returns
1028
+ -------
1029
+ None
1030
+ This method initializes the object attributes and does not return any value.
1031
+
1032
+ Notes
1033
+ -----
1034
+ This constructor sets up the basic configuration for high-resolution signal processing
1035
+ by storing the sampling frequency and number of steps, then calls sethiressignal()
1036
+ to process the input signal.
1037
+
1038
+ Examples
1039
+ --------
1040
+ >>> obj = MyClass(hiressignal, hires_Fs=44100, numsteps=100)
1041
+ >>> obj.hires_Fs
1042
+ 44100
1043
+ >>> obj.numsteps
1044
+ 100
1045
+ """
723
1046
  self.hires_Fs = hires_Fs
724
1047
  self.numsteps = numsteps
725
1048
  self.sethiressignal(hiressignal)
726
1049
 
727
1050
  def sethiressignal(self, hiressignal):
1051
+ """
1052
+ Set high resolution signal and compute related parameters.
1053
+
1054
+ This method processes the high resolution signal by normalizing it and computing
1055
+ correlation-related parameters including correlation length and correlation x-axis.
1056
+
1057
+ Parameters
1058
+ ----------
1059
+ hiressignal : array-like
1060
+ High resolution signal data to be processed and normalized.
1061
+
1062
+ Returns
1063
+ -------
1064
+ None
1065
+ This method modifies the instance attributes in-place and does not return a value.
1066
+
1067
+ Notes
1068
+ -----
1069
+ The method performs correlation normalization using `tide_math.corrnormalize` and
1070
+ computes the correlation length as `len(self.hiressignal) * 2 + 1`. The correlation
1071
+ x-axis is computed based on the sampling frequency (`self.hires_Fs`) and the length
1072
+ of the high resolution signal.
1073
+
1074
+ Examples
1075
+ --------
1076
+ >>> obj.sethiressignal(hiressignal_data)
1077
+ >>> print(obj.corrlen)
1078
+ 1001
1079
+ >>> print(obj.corrx.shape)
1080
+ (1001,)
1081
+ """
728
1082
  self.hiressignal = tide_math.corrnormalize(hiressignal)
729
1083
  self.corrlen = len(self.hiressignal) * 2 + 1
730
1084
  self.corrx = (
@@ -733,23 +1087,63 @@ class AliasedCorrelator:
733
1087
  )
734
1088
 
735
1089
  def getxaxis(self):
1090
+ """
1091
+ Return the x-axis correction value.
1092
+
1093
+ This method retrieves the correction value applied to the x-axis.
1094
+
1095
+ Returns
1096
+ -------
1097
+ float or int
1098
+ The correction value for the x-axis stored in `self.corrx`.
1099
+
1100
+ Notes
1101
+ -----
1102
+ The returned value represents the x-axis correction that has been
1103
+ previously computed or set in the object's `corrx` attribute.
1104
+
1105
+ Examples
1106
+ --------
1107
+ >>> obj = MyClass()
1108
+ >>> obj.corrx = 5.0
1109
+ >>> obj.getxaxis()
1110
+ 5.0
1111
+ """
736
1112
  return self.corrx
737
1113
 
738
1114
  def apply(self, loressignal, offset, debug=False):
739
- """Apply correlator to aliased waveform.
1115
+ """
1116
+ Apply correlator to aliased waveform.
1117
+
740
1118
  NB: Assumes the highres frequency is an integral multiple of the lowres frequency
1119
+
741
1120
  Parameters
742
1121
  ----------
743
- loressignal: 1D array
1122
+ loressignal : 1D array
744
1123
  The aliased waveform to match
745
- offset: int
1124
+ offset : int
746
1125
  Integer offset to apply to the upsampled lowressignal (to account for slice time offset)
747
- debug: bool, optional
1126
+ debug : bool, optional
748
1127
  Whether to print diagnostic information
1128
+
749
1129
  Returns
750
1130
  -------
751
- corrfunc: 1D array
1131
+ corrfunc : 1D array
752
1132
  The full correlation function
1133
+
1134
+ Notes
1135
+ -----
1136
+ This function applies a correlator to an aliased waveform by:
1137
+ 1. Creating an upsampled version of the high-resolution signal
1138
+ 2. Inserting the low-resolution signal at the specified offset
1139
+ 3. Computing the cross-correlation between the two signals
1140
+ 4. Normalizing the result by the square root of the number of steps
1141
+
1142
+ Examples
1143
+ --------
1144
+ >>> result = correlator.apply(signal, offset=5, debug=True)
1145
+ >>> print(result.shape)
1146
+ (len(highres_signal),)
753
1147
  """
754
1148
  if debug:
755
1149
  print(offset, self.numsteps)
@@ -762,13 +1156,63 @@ class AliasedCorrelator:
762
1156
 
763
1157
 
764
1158
  def matchsamplerates(
765
- input1,
766
- Fs1,
767
- input2,
768
- Fs2,
769
- method="univariate",
770
- debug=False,
771
- ):
1159
+ input1: NDArray,
1160
+ Fs1: float,
1161
+ input2: NDArray,
1162
+ Fs2: float,
1163
+ method: str = "univariate",
1164
+ debug: bool = False,
1165
+ ) -> Tuple[NDArray, NDArray, float]:
1166
+ """
1167
+ Match sampling rates of two input arrays by upsampling the lower sampling rate signal.
1168
+
1169
+ This function takes two input arrays with potentially different sampling rates and
1170
+ ensures they have the same sampling rate by upsampling the signal with the lower
1171
+ sampling rate to match the higher one. The function preserves the original data
1172
+ while adjusting the sampling rate for compatibility.
1173
+
1174
+ Parameters
1175
+ ----------
1176
+ input1 : NDArray
1177
+ First input array to be processed.
1178
+ Fs1 : float
1179
+ Sampling frequency of the first input array (Hz).
1180
+ input2 : NDArray
1181
+ Second input array to be processed.
1182
+ Fs2 : float
1183
+ Sampling frequency of the second input array (Hz).
1184
+ method : str, optional
1185
+ Resampling method to use, by default "univariate".
1186
+ See `tide_resample.upsample` for available methods.
1187
+ debug : bool, optional
1188
+ Enable debug output, by default False.
1189
+
1190
+ Returns
1191
+ -------
1192
+ Tuple[NDArray, NDArray, float]
1193
+ Tuple containing:
1194
+ - matchedinput1: First input array upsampled to match the sampling rate
1195
+ - matchedinput2: Second input array upsampled to match the sampling rate
1196
+ - corrFs: The common sampling frequency used for both outputs
1197
+
1198
+ Notes
1199
+ -----
1200
+ - If sampling rates are equal, no upsampling is performed
1201
+ - The function always upsamples to the higher sampling rate
1202
+ - The upsampling is performed using the `tide_resample.upsample` function
1203
+ - Both output arrays will have the same length and sampling rate
1204
+
1205
+ Examples
1206
+ --------
1207
+ >>> import numpy as np
1208
+ >>> input1 = np.array([1, 2, 3, 4])
1209
+ >>> input2 = np.array([5, 6, 7])
1210
+ >>> Fs1 = 10.0
1211
+ >>> Fs2 = 5.0
1212
+ >>> matched1, matched2, common_fs = matchsamplerates(input1, Fs1, input2, Fs2)
1213
+ >>> print(common_fs)
1214
+ 10.0
1215
+ """
772
1216
  if Fs1 > Fs2:
773
1217
  corrFs = Fs1
774
1218
  matchedinput1 = input1
@@ -785,16 +1229,74 @@ def matchsamplerates(
785
1229
 
786
1230
 
787
1231
  def arbcorr(
788
- input1,
789
- Fs1,
790
- input2,
791
- Fs2,
792
- start1=0.0,
793
- start2=0.0,
794
- windowfunc="hamming",
795
- method="univariate",
796
- debug=False,
797
- ):
1232
+ input1: NDArray,
1233
+ Fs1: float,
1234
+ input2: NDArray,
1235
+ Fs2: float,
1236
+ start1: float = 0.0,
1237
+ start2: float = 0.0,
1238
+ windowfunc: str = "hamming",
1239
+ method: str = "univariate",
1240
+ debug: bool = False,
1241
+ ) -> Tuple[NDArray, NDArray, float, int]:
1242
+ """
1243
+ Compute the cross-correlation between two signals with arbitrary sampling rates.
1244
+
1245
+ This function performs cross-correlation between two input signals after
1246
+ matching their sampling rates. It applies normalization and uses FFT-based
1247
+ convolution for efficient computation. The result includes the time lag axis,
1248
+ cross-correlation values, the matched sampling frequency, and the index of
1249
+ the zero-lag point.
1250
+
1251
+ Parameters
1252
+ ----------
1253
+ input1 : NDArray
1254
+ First input signal array.
1255
+ Fs1 : float
1256
+ Sampling frequency of the first signal (Hz).
1257
+ input2 : NDArray
1258
+ Second input signal array.
1259
+ Fs2 : float
1260
+ Sampling frequency of the second signal (Hz).
1261
+ start1 : float, optional
1262
+ Start time of the first signal (default is 0.0).
1263
+ start2 : float, optional
1264
+ Start time of the second signal (default is 0.0).
1265
+ windowfunc : str, optional
1266
+ Window function used for normalization (default is "hamming").
1267
+ method : str, optional
1268
+ Method used for matching sampling rates (default is "univariate").
1269
+ debug : bool, optional
1270
+ If True, enables debug logging (default is False).
1271
+
1272
+ Returns
1273
+ -------
1274
+ tuple
1275
+ A tuple containing:
1276
+ - thexcorr_x : NDArray
1277
+ Time lag axis for the cross-correlation (seconds).
1278
+ - thexcorr_y : NDArray
1279
+ Cross-correlation values.
1280
+ - corrFs : float
1281
+ Matched sampling frequency used for the computation (Hz).
1282
+ - zeroloc : int
1283
+ Index corresponding to the zero-lag point in the cross-correlation.
1284
+
1285
+ Notes
1286
+ -----
1287
+ - The function upsamples the signals to the higher of the two sampling rates.
1288
+ - Normalization is applied using a detrend order of 1 and the specified window function.
1289
+ - The cross-correlation is computed using FFT convolution for efficiency.
1290
+ - The zero-lag point is determined as the index of the minimum absolute value in the time axis.
1291
+
1292
+ Examples
1293
+ --------
1294
+ >>> import numpy as np
1295
+ >>> signal1 = np.random.randn(1000)
1296
+ >>> signal2 = np.random.randn(1000)
1297
+ >>> lags, corr_vals, fs, zero_idx = arbcorr(signal1, 10.0, signal2, 10.0)
1298
+ >>> print(f"Zero-lag index: {zero_idx}")
1299
+ """
798
1300
  # upsample to the higher frequency of the two
799
1301
  matchedinput1, matchedinput2, corrFs = matchsamplerates(
800
1302
  input1,
@@ -822,13 +1324,66 @@ def arbcorr(
822
1324
 
823
1325
 
824
1326
  def faststcorrelate(
825
- input1, input2, windowtype="hann", nperseg=32, weighting="None", displayplots=False
826
- ):
827
- """Perform correlation between short-time Fourier transformed arrays."""
1327
+ input1: NDArray,
1328
+ input2: NDArray,
1329
+ windowtype: str = "hann",
1330
+ nperseg: int = 32,
1331
+ weighting: str = "None",
1332
+ displayplots: bool = False,
1333
+ ) -> Tuple[NDArray, NDArray, NDArray]:
1334
+ """
1335
+ Perform correlation between short-time Fourier transformed arrays.
1336
+
1337
+ This function computes the short-time cross-correlation between two input signals
1338
+ using their short-time Fourier transforms (STFTs). It applies a windowing function
1339
+ to each signal, computes the STFT, and then performs correlation in the frequency
1340
+ domain before inverse transforming back to the time domain. The result is normalized
1341
+ by the auto-correlation of each signal.
1342
+
1343
+ Parameters
1344
+ ----------
1345
+ input1 : ndarray
1346
+ First input signal array.
1347
+ input2 : ndarray
1348
+ Second input signal array.
1349
+ windowtype : str, optional
1350
+ Type of window to apply. Default is 'hann'.
1351
+ nperseg : int, optional
1352
+ Length of each segment for STFT. Default is 32.
1353
+ weighting : str, optional
1354
+ Weighting method for the STFT. Default is 'None'.
1355
+ displayplots : bool, optional
1356
+ If True, display plots (not implemented in current version). Default is False.
1357
+
1358
+ Returns
1359
+ -------
1360
+ corrtimes : ndarray
1361
+ Time shifts corresponding to the correlation results.
1362
+ times : ndarray
1363
+ Time indices of the STFT.
1364
+ stcorr : ndarray
1365
+ Short-time cross-correlation values.
1366
+
1367
+ Notes
1368
+ -----
1369
+ The function uses `scipy.signal.stft` to compute the short-time Fourier transform
1370
+ of both input signals. The correlation is computed in the frequency domain and
1371
+ normalized by the square root of the auto-correlation of each signal.
1372
+
1373
+ Examples
1374
+ --------
1375
+ >>> import numpy as np
1376
+ >>> from scipy import signal
1377
+ >>> t = np.linspace(0, 1, 100)
1378
+ >>> x1 = np.sin(2 * np.pi * 5 * t)
1379
+ >>> x2 = np.sin(2 * np.pi * 5 * t + 0.1)
1380
+ >>> corrtimes, times, corr = faststcorrelate(x1, x2)
1381
+ >>> print(corr.shape)
1382
+ (32, 100)
1383
+ """
828
1384
  nfft = nperseg
829
1385
  noverlap = nperseg - 1
830
1386
  onesided = False
831
- boundary = "even"
832
1387
  freqs, times, thestft1 = signal.stft(
833
1388
  input1,
834
1389
  fs=1.0,
@@ -838,7 +1393,7 @@ def faststcorrelate(
838
1393
  nfft=nfft,
839
1394
  detrend="linear",
840
1395
  return_onesided=onesided,
841
- boundary=boundary,
1396
+ boundary="even",
842
1397
  padded=True,
843
1398
  axis=-1,
844
1399
  )
@@ -852,7 +1407,7 @@ def faststcorrelate(
852
1407
  nfft=nfft,
853
1408
  detrend="linear",
854
1409
  return_onesided=onesided,
855
- boundary=boundary,
1410
+ boundary="even",
856
1411
  padded=True,
857
1412
  axis=-1,
858
1413
  )
@@ -878,7 +1433,40 @@ def faststcorrelate(
878
1433
  return corrtimes, times, stcorr
879
1434
 
880
1435
 
881
- def primefacs(thelen):
1436
+ def primefacs(thelen: int) -> list:
1437
+ """
1438
+ Compute the prime factorization of a given integer.
1439
+
1440
+ Parameters
1441
+ ----------
1442
+ thelen : int
1443
+ The positive integer to factorize. Must be greater than 0.
1444
+
1445
+ Returns
1446
+ -------
1447
+ list
1448
+ A list of prime factors of `thelen`, sorted in ascending order.
1449
+ Each factor appears as many times as its multiplicity in the
1450
+ prime factorization.
1451
+
1452
+ Notes
1453
+ -----
1454
+ This function implements trial division algorithm to find prime factors.
1455
+ The algorithm starts with the smallest prime (2) and continues with
1456
+ increasing integers until the square root of the remaining number.
1457
+ The final remaining number (if greater than 1) is also a prime factor.
1458
+
1459
+ Examples
1460
+ --------
1461
+ >>> primefacs(12)
1462
+ [2, 2, 3]
1463
+
1464
+ >>> primefacs(17)
1465
+ [17]
1466
+
1467
+ >>> primefacs(100)
1468
+ [2, 2, 5, 5]
1469
+ """
882
1470
  i = 2
883
1471
  factors = []
884
1472
  while i * i <= thelen:
@@ -891,40 +1479,99 @@ def primefacs(thelen):
891
1479
  return factors
892
1480
 
893
1481
 
894
- def optfftlen(thelen):
1482
+ def optfftlen(thelen: int) -> int:
1483
+ """
1484
+ Calculate optimal FFT length for given input length.
1485
+
1486
+ This function currently returns the input length as-is, but is designed
1487
+ to be extended for optimal FFT length calculation based on hardware
1488
+ constraints or performance considerations.
1489
+
1490
+ Parameters
1491
+ ----------
1492
+ thelen : int
1493
+ The input length for which to calculate optimal FFT length.
1494
+ Must be a positive integer.
1495
+
1496
+ Returns
1497
+ -------
1498
+ int
1499
+ The optimal FFT length. For the current implementation, this
1500
+ simply returns the input `thelen` value.
1501
+
1502
+ Notes
1503
+ -----
1504
+ In a more complete implementation, this function would calculate
1505
+ the optimal FFT length by finding the smallest number >= thelen
1506
+ that has only small prime factors (2, 3, 5, 7) for optimal
1507
+ performance on most FFT implementations.
1508
+
1509
+ Examples
1510
+ --------
1511
+ >>> optfftlen(1024)
1512
+ 1024
1513
+ >>> optfftlen(1000)
1514
+ 1000
1515
+ """
895
1516
  return thelen
896
1517
 
897
1518
 
898
1519
  def fastcorrelate(
899
- input1,
900
- input2,
901
- usefft=True,
902
- zeropadding=0,
903
- weighting="None",
904
- compress=False,
905
- displayplots=False,
906
- debug=False,
907
- ):
908
- """Perform a fast correlation between two arrays.
1520
+ input1: NDArray,
1521
+ input2: NDArray,
1522
+ usefft: bool = True,
1523
+ zeropadding: int = 0,
1524
+ weighting: str = "None",
1525
+ compress: bool = False,
1526
+ displayplots: bool = False,
1527
+ debug: bool = False,
1528
+ ) -> NDArray:
1529
+ """
1530
+ Perform a fast correlation between two arrays.
1531
+
1532
+ This function computes the cross-correlation of two input arrays, with options
1533
+ for using FFT-based convolution or direct correlation, as well as padding and
1534
+ weighting schemes.
909
1535
 
910
1536
  Parameters
911
1537
  ----------
912
- input1
913
- input2
914
- usefft
915
- zeropadding
916
- weighting
917
- compress
918
- displayplots
919
- debug
1538
+ input1 : ndarray
1539
+ First input array to correlate.
1540
+ input2 : ndarray
1541
+ Second input array to correlate.
1542
+ usefft : bool, optional
1543
+ If True, use FFT-based convolution for faster computation. Default is True.
1544
+ zeropadding : int, optional
1545
+ Zero-padding length. If 0, no padding is applied. If negative, automatic
1546
+ padding is applied. If positive, explicit padding is applied. Default is 0.
1547
+ weighting : str, optional
1548
+ Type of weighting to apply. If "None", no weighting is applied. Default is "None".
1549
+ compress : bool, optional
1550
+ If True and `weighting` is not "None", compress the result. Default is False.
1551
+ displayplots : bool, optional
1552
+ If True, display plots of padded inputs and correlation result. Default is False.
1553
+ debug : bool, optional
1554
+ If True, enable debug output. Default is False.
920
1555
 
921
1556
  Returns
922
1557
  -------
923
- corr
1558
+ ndarray
1559
+ The cross-correlation of `input1` and `input2`. The length of the output is
1560
+ `len(input1) + len(input2) - 1`.
924
1561
 
925
1562
  Notes
926
1563
  -----
927
- From http://stackoverflow.com/questions/12323959/fast-cross-correlation-method-in-python.
1564
+ This implementation is based on the method described at:
1565
+ http://stackoverflow.com/questions/12323959/fast-cross-correlation-method-in-python
1566
+
1567
+ Examples
1568
+ --------
1569
+ >>> import numpy as np
1570
+ >>> a = np.array([1, 2, 3])
1571
+ >>> b = np.array([0, 1, 0])
1572
+ >>> result = fastcorrelate(a, b)
1573
+ >>> print(result)
1574
+ [0. 1. 2. 3. 0.]
928
1575
  """
929
1576
  len1 = len(input1)
930
1577
  len2 = len(input2)
@@ -995,17 +1642,36 @@ def fastcorrelate(
995
1642
  return np.correlate(paddedinput1, paddedinput2, mode="full")
996
1643
 
997
1644
 
998
- def _centered(arr, newsize):
999
- """Return the center newsize portion of the array.
1645
+ def _centered(arr: NDArray, newsize: Union[int, NDArray]) -> NDArray:
1646
+ """
1647
+ Extract a centered subset of an array.
1000
1648
 
1001
1649
  Parameters
1002
1650
  ----------
1003
- arr
1004
- newsize
1651
+ arr : array_like
1652
+ Input array from which to extract the centered subset.
1653
+ newsize : int or array_like
1654
+ The size of the output array. If int, the same size is used for all dimensions.
1655
+ If array_like, specifies the size for each dimension.
1005
1656
 
1006
1657
  Returns
1007
1658
  -------
1008
- arr
1659
+ ndarray
1660
+ Centered subset of the input array with the specified size.
1661
+
1662
+ Notes
1663
+ -----
1664
+ The function extracts a subset from the center of the input array. If the requested
1665
+ size is larger than the current array size in any dimension, the result will be
1666
+ padded with zeros (or the array will be truncated from the center).
1667
+
1668
+ Examples
1669
+ --------
1670
+ >>> import numpy as np
1671
+ >>> arr = np.arange(24).reshape(4, 6)
1672
+ >>> _centered(arr, (2, 3))
1673
+ array([[ 7, 8, 9],
1674
+ [13, 14, 15]])
1009
1675
  """
1010
1676
  newsize = np.asarray(newsize)
1011
1677
  currsize = np.array(arr.shape)
@@ -1015,22 +1681,36 @@ def _centered(arr, newsize):
1015
1681
  return arr[tuple(myslice)]
1016
1682
 
1017
1683
 
1018
- def _check_valid_mode_shapes(shape1, shape2):
1019
- """Check that two shapes are 'valid' with respect to one another.
1020
-
1021
- Specifically, this checks that each item in one tuple is larger than or
1022
- equal to corresponding item in another tuple.
1684
+ def _check_valid_mode_shapes(shape1: Tuple, shape2: Tuple) -> None:
1685
+ """
1686
+ Check that shape1 is valid for 'valid' mode convolution with shape2.
1023
1687
 
1024
1688
  Parameters
1025
1689
  ----------
1026
- shape1
1027
- shape2
1028
-
1029
- Raises
1030
- ------
1031
- ValueError
1032
- If at least one item in the first shape is not larger than or equal to
1033
- the corresponding item in the second one.
1690
+ shape1 : Tuple
1691
+ First shape tuple to compare
1692
+ shape2 : Tuple
1693
+ Second shape tuple to compare
1694
+
1695
+ Returns
1696
+ -------
1697
+ None
1698
+ This function does not return anything but raises ValueError if condition is not met
1699
+
1700
+ Notes
1701
+ -----
1702
+ This function is used to validate that the first shape has at least as many
1703
+ elements as the second shape in every dimension, which is required for
1704
+ 'valid' mode convolution operations.
1705
+
1706
+ Examples
1707
+ --------
1708
+ >>> _check_valid_mode_shapes((10, 10), (5, 5))
1709
+ >>> _check_valid_mode_shapes((10, 10), (10, 5))
1710
+ >>> _check_valid_mode_shapes((5, 5), (10, 5))
1711
+ Traceback (most recent call last):
1712
+ ...
1713
+ ValueError: in1 should have at least as many items as in2 in every dimension for 'valid' mode.
1034
1714
  """
1035
1715
  for d1, d2 in zip(shape1, shape2):
1036
1716
  if not d1 >= d2:
@@ -1041,22 +1721,29 @@ def _check_valid_mode_shapes(shape1, shape2):
1041
1721
 
1042
1722
 
1043
1723
  def convolve_weighted_fft(
1044
- in1, in2, mode="full", weighting="None", compress=False, displayplots=False
1045
- ):
1046
- """Convolve two N-dimensional arrays using FFT.
1724
+ in1: NDArray[np.floating[Any]],
1725
+ in2: NDArray[np.floating[Any]],
1726
+ mode: str = "full",
1727
+ weighting: str = "None",
1728
+ compress: bool = False,
1729
+ displayplots: bool = False,
1730
+ ) -> NDArray[np.floating[Any]]:
1731
+ """
1732
+ Convolve two N-dimensional arrays using FFT with optional weighting.
1047
1733
 
1048
1734
  Convolve `in1` and `in2` using the fast Fourier transform method, with
1049
- the output size determined by the `mode` argument.
1050
- This is generally much faster than `convolve` for large arrays (n > ~500),
1051
- but can be slower when only a few output values are needed, and can only
1052
- output float arrays (int or object array inputs will be cast to float).
1735
+ the output size determined by the `mode` argument. This is generally much
1736
+ faster than `convolve` for large arrays (n > ~500), but can be slower when
1737
+ only a few output values are needed. The function supports both real and
1738
+ complex inputs, and allows for optional weighting and compression of the
1739
+ FFT operations.
1053
1740
 
1054
1741
  Parameters
1055
1742
  ----------
1056
- in1 : array_like
1057
- First input.
1058
- in2 : array_like
1059
- Second input. Should have the same number of dimensions as `in1`;
1743
+ in1 : NDArray[np.floating[Any]]
1744
+ First input array.
1745
+ in2 : NDArray[np.floating[Any]]
1746
+ Second input array. Should have the same number of dimensions as `in1`;
1060
1747
  if sizes of `in1` and `in2` are not equal then `in1` has to be the
1061
1748
  larger array.
1062
1749
  mode : str {'full', 'valid', 'same'}, optional
@@ -1072,19 +1759,45 @@ def convolve_weighted_fft(
1072
1759
  ``same``
1073
1760
  The output is the same size as `in1`, centered
1074
1761
  with respect to the 'full' output.
1762
+ weighting : str, optional
1763
+ Type of weighting to apply during convolution. Default is "None".
1764
+ Options are 'liang', 'eckart', 'phat', and 'regressor', as
1765
+ implemented by `gccproduct`.
1766
+ compress : bool, optional
1767
+ If True, compress the FFT data during computation. Default is False.
1768
+ displayplots : bool, optional
1769
+ If True, display intermediate plots during computation. Default is False.
1075
1770
 
1076
1771
  Returns
1077
1772
  -------
1078
- out : array
1773
+ out : NDArray[np.floating[Any]]
1079
1774
  An N-dimensional array containing a subset of the discrete linear
1080
- convolution of `in1` with `in2`.
1081
- """
1082
- in1 = np.asarray(in1)
1083
- in2 = np.asarray(in2)
1775
+ convolution of `in1` with `in2`. The shape of the output depends on
1776
+ the `mode` parameter.
1084
1777
 
1085
- if np.isscalar(in1) and np.isscalar(in2): # scalar inputs
1086
- return in1 * in2
1087
- elif not in1.ndim == in2.ndim:
1778
+ Notes
1779
+ -----
1780
+ - This function uses real FFT (`rfftn`) for real inputs and standard FFT
1781
+ (`fftpack.fftn`) for complex inputs.
1782
+ - The convolution is computed in the frequency domain using the product
1783
+ of FFTs of the inputs.
1784
+ - For real inputs, the result is scaled to preserve the maximum amplitude.
1785
+ - The `gccproduct` function is used internally to compute the product
1786
+ of the FFTs with optional weighting.
1787
+
1788
+ Examples
1789
+ --------
1790
+ >>> import numpy as np
1791
+ >>> a = np.array([[1, 2], [3, 4]])
1792
+ >>> b = np.array([[1, 0], [0, 1]])
1793
+ >>> result = convolve_weighted_fft(a, b)
1794
+ >>> result.shape  # 'full' mode: (2 + 2 - 1, 2 + 2 - 1)
1795
+ (3, 3)
1797
+ """
1798
+ # if np.isscalar(in1) and np.isscalar(in2): # scalar inputs
1799
+ # return in1 * in2
1800
+ if not in1.ndim == in2.ndim:
1088
1801
  raise ValueError("in1 and in2 should have the same rank")
1089
1802
  elif in1.size == 0 or in2.size == 0: # empty arrays
1090
1803
  return np.array([])
@@ -1128,27 +1841,73 @@ def convolve_weighted_fft(
1128
1841
  # scale to preserve the maximum
1129
1842
 
1130
1843
  if mode == "full":
1131
- return ret
1844
+ retval = ret
1132
1845
  elif mode == "same":
1133
- return _centered(ret, s1)
1846
+ retval = _centered(ret, s1)
1134
1847
  elif mode == "valid":
1135
- return _centered(ret, s1 - s2 + 1)
1848
+ retval = _centered(ret, s1 - s2 + 1)
1136
1849
 
1850
+ return retval
1851
+
1852
+
1853
+ def gccproduct(
1854
+ fft1: NDArray,
1855
+ fft2: NDArray,
1856
+ weighting: str,
1857
+ threshfrac: float = 0.1,
1858
+ compress: bool = False,
1859
+ displayplots: bool = False,
1860
+ ) -> NDArray:
1861
+ """
1862
+ Compute the generalized cross-correlation (GCC) product with optional weighting.
1137
1863
 
1138
- def gccproduct(fft1, fft2, weighting, threshfrac=0.1, compress=False, displayplots=False):
1139
- """Calculate product for generalized crosscorrelation.
1864
+ This function computes the GCC product of two FFT arrays, applying a specified
1865
+ weighting scheme to enhance correlation performance. It supports several weighting
1866
+ methods including 'liang', 'eckart', 'phat', and 'regressor'. The result can be
1867
+ thresholded and optionally compressed to improve visualization and reduce noise.
1140
1868
 
1141
1869
  Parameters
1142
1870
  ----------
1143
- fft1
1144
- fft2
1145
- weighting
1146
- threshfrac
1147
- displayplots
1871
+ fft1 : NDArray
1872
+ First FFT array (complex-valued).
1873
+ fft2 : NDArray
1874
+ Second FFT array (complex-valued).
1875
+ weighting : str
1876
+ Weighting method to apply. Options are:
1877
+ - 'liang': Liang weighting
1878
+ - 'eckart': Eckart weighting
1879
+ - 'phat': PHAT (Phase Transform) weighting
1880
+ - 'regressor': Regressor-based weighting (uses fft2 as reference)
1881
+ - 'None': No weighting applied.
1882
+ threshfrac : float, optional
1883
+ Threshold fraction used to determine the minimum value for output masking.
1884
+ Default is 0.1.
1885
+ compress : bool, optional
1886
+ If True, compress the weighting function using 10th and 90th percentiles.
1887
+ Default is False.
1888
+ displayplots : bool, optional
1889
+ If True, display the reciprocal weighting function as a plot.
1890
+ Default is False.
1148
1891
 
1149
1892
  Returns
1150
1893
  -------
1151
- product
1894
+ NDArray
1895
+ The weighted GCC product. The output is of the same shape as the input arrays.
1896
+ If `weighting` is 'None', the raw product is returned.
1897
+ If `threshfrac` is 0, a zero array of the same shape is returned.
1898
+
1899
+ Notes
1900
+ -----
1901
+ The weighting functions are applied element-wise and are designed to suppress
1902
+ noise and enhance correlation peaks. The 'phat' weighting is commonly used in
1903
+ speech and signal processing due to its robustness.
1904
+
1905
+ Examples
1906
+ --------
1907
+ >>> import numpy as np
1908
+ >>> fft1 = np.random.rand(100) + 1j * np.random.rand(100)
1909
+ >>> fft2 = np.random.rand(100) + 1j * np.random.rand(100)
1910
+ >>> result = gccproduct(fft1, fft2, weighting='phat', threshfrac=0.05)
1152
1911
  """
1153
1912
  product = fft1 * fft2
1154
1913
  if weighting == "None":
@@ -1197,17 +1956,75 @@ def gccproduct(fft1, fft2, weighting, threshfrac=0.1, compress=False, displayplo
1197
1956
 
1198
1957
 
1199
1958
  def aligntcwithref(
1200
- fixedtc,
1201
- movingtc,
1202
- Fs,
1203
- lagmin=-30,
1204
- lagmax=30,
1205
- refine=True,
1206
- zerooutbadfit=False,
1207
- widthmax=1000.0,
1208
- display=False,
1209
- verbose=False,
1210
- ):
1959
+ fixedtc: NDArray,
1960
+ movingtc: NDArray,
1961
+ Fs: float,
1962
+ lagmin: float = -30,
1963
+ lagmax: float = 30,
1964
+ refine: bool = True,
1965
+ zerooutbadfit: bool = False,
1966
+ widthmax: float = 1000.0,
1967
+ display: bool = False,
1968
+ verbose: bool = False,
1969
+ ) -> Tuple[NDArray, float, float, int]:
1970
+ """
1971
+ Align a moving timecourse to a fixed reference timecourse using cross-correlation.
1972
+
1973
+ This function computes the cross-correlation between two timecourses and finds the
1974
+ optimal time lag that maximizes their similarity. The moving timecourse is then
1975
+ aligned to the fixed one using this lag.
1976
+
1977
+ Parameters
1978
+ ----------
1979
+ fixedtc : ndarray
1980
+ The reference timecourse to which the moving timecourse will be aligned.
1981
+ movingtc : ndarray
1982
+ The timecourse to be aligned to the fixed timecourse.
1983
+ Fs : float
1984
+ Sampling frequency of the timecourses in Hz.
1985
+ lagmin : float, optional
1986
+ Minimum lag to consider in seconds. Default is -30.
1987
+ lagmax : float, optional
1988
+ Maximum lag to consider in seconds. Default is 30.
1989
+ refine : bool, optional
1990
+ If True, refine the lag estimate using Gaussian fitting. Default is True.
1991
+ zerooutbadfit : bool, optional
1992
+ If True, zero out the cross-correlation values for bad fits. Default is False.
1993
+ widthmax : float, optional
1994
+ Maximum allowed width of the Gaussian fit in samples. Default is 1000.0.
1995
+ display : bool, optional
1996
+ If True, display plots of the cross-correlation and aligned timecourses. Default is False.
1997
+ verbose : bool, optional
1998
+ If True, print detailed information about the cross-correlation results. Default is False.
1999
+
2000
+ Returns
2001
+ -------
2002
+ tuple
2003
+ A tuple containing:
2004
+ - aligneddata : ndarray
2005
+ The moving timecourse aligned to the fixed timecourse.
2006
+ - maxdelay : float
2007
+ The estimated time lag (in seconds) that maximizes cross-correlation.
2008
+ - maxval : float
2009
+ The maximum cross-correlation value.
2010
+ - failreason : int
2011
+ Reason for failure (0 = success, other values indicate specific failure types).
2012
+
2013
+ Notes
2014
+ -----
2015
+ This function uses `fastcorrelate` for efficient cross-correlation computation and
2016
+ `tide_fit.findmaxlag_gauss` to estimate the optimal lag with optional Gaussian refinement.
2017
+ The alignment is performed using `tide_resample.doresample`.
2018
+
2019
+ Examples
2020
+ --------
2021
+ >>> import numpy as np
2022
+ >>> from typing import Tuple
2023
+ >>> fixed = np.random.rand(1000)
2024
+ >>> moving = np.roll(fixed, 10) # shift by 10 samples
2025
+ >>> aligned, delay, corr, fail = aligntcwithref(fixed, moving, Fs=100)
2026
+ >>> print(f"Estimated delay: {delay}s")
2027
+ """
1211
2028
  # now fixedtc and 2 are on the same timescales
1212
2029
  thexcorr = fastcorrelate(tide_math.corrnormalize(fixedtc), tide_math.corrnormalize(movingtc))
1213
2030
  xcorrlen = len(thexcorr)