rapidtide 3.0.11__py3-none-any.whl → 3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. rapidtide/Colortables.py +492 -27
  2. rapidtide/OrthoImageItem.py +1049 -46
  3. rapidtide/RapidtideDataset.py +1533 -86
  4. rapidtide/_version.py +3 -3
  5. rapidtide/calccoherence.py +196 -29
  6. rapidtide/calcnullsimfunc.py +188 -40
  7. rapidtide/calcsimfunc.py +242 -42
  8. rapidtide/correlate.py +1203 -383
  9. rapidtide/data/examples/src/testLD +56 -0
  10. rapidtide/data/examples/src/testalign +1 -1
  11. rapidtide/data/examples/src/testdelayvar +0 -1
  12. rapidtide/data/examples/src/testfmri +53 -3
  13. rapidtide/data/examples/src/testglmfilt +5 -5
  14. rapidtide/data/examples/src/testhappy +29 -7
  15. rapidtide/data/examples/src/testppgproc +17 -0
  16. rapidtide/data/examples/src/testrolloff +11 -0
  17. rapidtide/data/models/model_cnn_pytorch/best_model.pth +0 -0
  18. rapidtide/data/models/model_cnn_pytorch/loss.png +0 -0
  19. rapidtide/data/models/model_cnn_pytorch/loss.txt +1 -0
  20. rapidtide/data/models/model_cnn_pytorch/model.pth +0 -0
  21. rapidtide/data/models/model_cnn_pytorch/model_meta.json +68 -0
  22. rapidtide/decorators.py +91 -0
  23. rapidtide/dlfilter.py +2226 -110
  24. rapidtide/dlfiltertorch.py +4842 -0
  25. rapidtide/externaltools.py +327 -12
  26. rapidtide/fMRIData_class.py +79 -40
  27. rapidtide/filter.py +1899 -810
  28. rapidtide/fit.py +2011 -581
  29. rapidtide/genericmultiproc.py +93 -18
  30. rapidtide/happy_supportfuncs.py +2047 -172
  31. rapidtide/helper_classes.py +584 -43
  32. rapidtide/io.py +2370 -372
  33. rapidtide/linfitfiltpass.py +346 -99
  34. rapidtide/makelaggedtcs.py +210 -24
  35. rapidtide/maskutil.py +448 -62
  36. rapidtide/miscmath.py +827 -121
  37. rapidtide/multiproc.py +210 -22
  38. rapidtide/patchmatch.py +242 -42
  39. rapidtide/peakeval.py +31 -31
  40. rapidtide/ppgproc.py +2203 -0
  41. rapidtide/qualitycheck.py +352 -39
  42. rapidtide/refinedelay.py +431 -57
  43. rapidtide/refineregressor.py +494 -189
  44. rapidtide/resample.py +671 -185
  45. rapidtide/scripts/applyppgproc.py +28 -0
  46. rapidtide/scripts/showxcorr_legacy.py +7 -7
  47. rapidtide/scripts/stupidramtricks.py +15 -17
  48. rapidtide/simFuncClasses.py +1052 -77
  49. rapidtide/simfuncfit.py +269 -69
  50. rapidtide/stats.py +540 -238
  51. rapidtide/tests/happycomp +9 -0
  52. rapidtide/tests/test_cleanregressor.py +1 -2
  53. rapidtide/tests/test_dlfiltertorch.py +627 -0
  54. rapidtide/tests/test_findmaxlag.py +24 -8
  55. rapidtide/tests/test_fullrunhappy_v1.py +0 -2
  56. rapidtide/tests/test_fullrunhappy_v2.py +0 -2
  57. rapidtide/tests/test_fullrunhappy_v3.py +11 -4
  58. rapidtide/tests/test_fullrunhappy_v4.py +10 -2
  59. rapidtide/tests/test_fullrunrapidtide_v7.py +1 -1
  60. rapidtide/tests/test_getparsers.py +11 -3
  61. rapidtide/tests/test_refinedelay.py +0 -1
  62. rapidtide/tests/test_simroundtrip.py +16 -8
  63. rapidtide/tests/test_stcorrelate.py +3 -1
  64. rapidtide/tests/utils.py +9 -8
  65. rapidtide/tidepoolTemplate.py +142 -38
  66. rapidtide/tidepoolTemplate_alt.py +165 -44
  67. rapidtide/tidepoolTemplate_big.py +189 -52
  68. rapidtide/util.py +1217 -118
  69. rapidtide/voxelData.py +684 -37
  70. rapidtide/wiener.py +136 -23
  71. rapidtide/wiener2.py +113 -7
  72. rapidtide/workflows/adjustoffset.py +105 -3
  73. rapidtide/workflows/aligntcs.py +85 -2
  74. rapidtide/workflows/applydlfilter.py +87 -10
  75. rapidtide/workflows/applyppgproc.py +540 -0
  76. rapidtide/workflows/atlasaverage.py +210 -47
  77. rapidtide/workflows/atlastool.py +100 -3
  78. rapidtide/workflows/calcSimFuncMap.py +288 -69
  79. rapidtide/workflows/calctexticc.py +201 -9
  80. rapidtide/workflows/ccorrica.py +101 -6
  81. rapidtide/workflows/cleanregressor.py +165 -31
  82. rapidtide/workflows/delayvar.py +171 -23
  83. rapidtide/workflows/diffrois.py +81 -3
  84. rapidtide/workflows/endtidalproc.py +144 -4
  85. rapidtide/workflows/fdica.py +195 -15
  86. rapidtide/workflows/filtnifti.py +70 -3
  87. rapidtide/workflows/filttc.py +74 -3
  88. rapidtide/workflows/fitSimFuncMap.py +202 -51
  89. rapidtide/workflows/fixtr.py +73 -3
  90. rapidtide/workflows/gmscalc.py +113 -3
  91. rapidtide/workflows/happy.py +801 -199
  92. rapidtide/workflows/happy2std.py +144 -12
  93. rapidtide/workflows/happy_parser.py +163 -23
  94. rapidtide/workflows/histnifti.py +118 -2
  95. rapidtide/workflows/histtc.py +84 -3
  96. rapidtide/workflows/linfitfilt.py +117 -4
  97. rapidtide/workflows/localflow.py +328 -28
  98. rapidtide/workflows/mergequality.py +79 -3
  99. rapidtide/workflows/niftidecomp.py +322 -18
  100. rapidtide/workflows/niftistats.py +174 -4
  101. rapidtide/workflows/pairproc.py +98 -4
  102. rapidtide/workflows/pairwisemergenifti.py +85 -2
  103. rapidtide/workflows/parser_funcs.py +1421 -40
  104. rapidtide/workflows/physiofreq.py +137 -11
  105. rapidtide/workflows/pixelcomp.py +207 -5
  106. rapidtide/workflows/plethquality.py +103 -21
  107. rapidtide/workflows/polyfitim.py +151 -11
  108. rapidtide/workflows/proj2flow.py +75 -2
  109. rapidtide/workflows/rankimage.py +111 -4
  110. rapidtide/workflows/rapidtide.py +368 -76
  111. rapidtide/workflows/rapidtide2std.py +98 -2
  112. rapidtide/workflows/rapidtide_parser.py +109 -9
  113. rapidtide/workflows/refineDelayMap.py +144 -33
  114. rapidtide/workflows/refineRegressor.py +675 -96
  115. rapidtide/workflows/regressfrommaps.py +161 -37
  116. rapidtide/workflows/resamplenifti.py +85 -3
  117. rapidtide/workflows/resampletc.py +91 -3
  118. rapidtide/workflows/retrolagtcs.py +99 -9
  119. rapidtide/workflows/retroregress.py +176 -26
  120. rapidtide/workflows/roisummarize.py +174 -5
  121. rapidtide/workflows/runqualitycheck.py +71 -3
  122. rapidtide/workflows/showarbcorr.py +149 -6
  123. rapidtide/workflows/showhist.py +86 -2
  124. rapidtide/workflows/showstxcorr.py +160 -3
  125. rapidtide/workflows/showtc.py +159 -3
  126. rapidtide/workflows/showxcorrx.py +190 -10
  127. rapidtide/workflows/showxy.py +185 -15
  128. rapidtide/workflows/simdata.py +264 -38
  129. rapidtide/workflows/spatialfit.py +77 -2
  130. rapidtide/workflows/spatialmi.py +250 -27
  131. rapidtide/workflows/spectrogram.py +305 -32
  132. rapidtide/workflows/synthASL.py +154 -3
  133. rapidtide/workflows/tcfrom2col.py +76 -2
  134. rapidtide/workflows/tcfrom3col.py +74 -2
  135. rapidtide/workflows/tidepool.py +2971 -130
  136. rapidtide/workflows/utils.py +19 -14
  137. rapidtide/workflows/utils_doc.py +293 -0
  138. rapidtide/workflows/variabilityizer.py +116 -3
  139. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/METADATA +10 -8
  140. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/RECORD +144 -128
  141. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/entry_points.txt +1 -0
  142. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/WHEEL +0 -0
  143. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/licenses/LICENSE +0 -0
  144. {rapidtide-3.0.11.dist-info → rapidtide-3.1.1.dist-info}/top_level.txt +0 -0
rapidtide/correlate.py CHANGED
@@ -19,9 +19,11 @@
19
19
  """Functions for calculating correlations and similar metrics between arrays."""
20
20
  import logging
21
21
  import warnings
22
+ from typing import Any, Callable, Optional, Tuple, Union
22
23
 
23
24
  import matplotlib.pyplot as plt
24
25
  import numpy as np
26
+ from numpy.typing import NDArray
25
27
 
26
28
  with warnings.catch_warnings():
27
29
  warnings.simplefilter("ignore")
@@ -44,6 +46,7 @@ import rapidtide.miscmath as tide_math
44
46
  import rapidtide.resample as tide_resample
45
47
  import rapidtide.stats as tide_stats
46
48
  import rapidtide.util as tide_util
49
+ from rapidtide.decorators import conditionaljit
47
50
 
48
51
  if pyfftwpresent:
49
52
  fftpack = pyfftw.interfaces.scipy_fftpack
@@ -56,60 +59,62 @@ MAXLINES = 10000000
56
59
  donotbeaggressive = True
57
60
 
58
61
 
59
- # ----------------------------------------- Conditional imports ---------------------------------------
60
- try:
61
- from numba import jit
62
- except ImportError:
63
- donotusenumba = True
64
- else:
65
- donotusenumba = False
66
-
67
-
68
- def conditionaljit():
69
- """Wrap functions in jit if numba is enabled."""
70
-
71
- def resdec(f):
72
- if donotusenumba:
73
- return f
74
- return jit(f, nopython=True)
75
-
76
- return resdec
77
-
78
-
79
- def disablenumba():
80
- """Set a global variable to disable numba."""
81
- global donotusenumba
82
- donotusenumba = True
83
-
84
-
85
62
  # --------------------------- Correlation functions -------------------------------------------------
86
63
  def check_autocorrelation(
87
- corrscale,
88
- thexcorr,
89
- delta=0.05,
90
- acampthresh=0.1,
91
- aclagthresh=10.0,
92
- displayplots=False,
93
- detrendorder=1,
94
- debug=False,
95
- ):
96
- """Check for autocorrelation in an array.
64
+ corrscale: NDArray,
65
+ thexcorr: NDArray,
66
+ delta: float = 0.05,
67
+ acampthresh: float = 0.1,
68
+ aclagthresh: float = 10.0,
69
+ displayplots: bool = False,
70
+ detrendorder: int = 1,
71
+ debug: bool = False,
72
+ ) -> Tuple[Optional[float], Optional[float]]:
73
+ """
74
+ Check for autocorrelation peaks in a cross-correlation signal and fit a Gaussian to the sidelobe.
75
+
76
+ This function identifies peaks in the cross-correlation signal and, if a significant
77
+ sidelobe is detected (based on amplitude and lag thresholds), fits a Gaussian function
78
+ to estimate the sidelobe's time and amplitude.
97
79
 
98
80
  Parameters
99
81
  ----------
100
- corrscale
101
- thexcorr
102
- delta
103
- acampthresh
104
- aclagthresh
105
- displayplots
106
- windowfunc
107
- detrendorder
82
+ corrscale : NDArray
83
+ Array of time lags corresponding to the cross-correlation values.
84
+ thexcorr : NDArray
85
+ Array of cross-correlation values.
86
+ delta : float, optional
87
+ Minimum distance between peaks, default is 0.05.
88
+ acampthresh : float, optional
89
+ Amplitude threshold for detecting sidelobes, default is 0.1.
90
+ aclagthresh : float, optional
91
+ Lag threshold beyond which sidelobes are ignored, default is 10.0.
92
+ displayplots : bool, optional
93
+ If True, display the cross-correlation plot with detected peaks, default is False.
94
+ detrendorder : int, optional
95
+ Order of detrending to apply to the signal, default is 1.
96
+ debug : bool, optional
97
+ If True, print debug information, default is False.
108
98
 
109
99
  Returns
110
100
  -------
111
- sidelobetime
112
- sidelobeamp
101
+ Tuple[Optional[float], Optional[float]]
102
+ A tuple containing the estimated sidelobe time and amplitude if a valid sidelobe is found,
103
+ otherwise (None, None).
104
+
105
+ Notes
106
+ -----
107
+ - The function uses `peakdetect` to find peaks in the cross-correlation.
108
+ - A Gaussian fit is performed only if a peak is found beyond the zero-lag point and
109
+ satisfies the amplitude and lag thresholds.
110
+ - The fit is performed on a window around the detected sidelobe.
111
+
112
+ Examples
113
+ --------
114
+ >>> corrscale = np.linspace(0, 20, 100)
115
+ >>> thexcorr = np.exp(-0.5 * (corrscale - 5)**2 / 2) + 0.1 * np.random.rand(100)
116
+ >>> time, amp = check_autocorrelation(corrscale, thexcorr, delta=0.1, acampthresh=0.05)
117
+ >>> print(f"Sidelobe time: {time}, Amplitude: {amp}")
113
118
  """
114
119
  if debug:
115
120
  print("check_autocorrelation:")
@@ -176,31 +181,61 @@ def check_autocorrelation(
176
181
 
177
182
 
178
183
  def shorttermcorr_1D(
179
- data1,
180
- data2,
181
- sampletime,
182
- windowtime,
183
- samplestep=1,
184
- detrendorder=0,
185
- windowfunc="hamming",
186
- ):
187
- """Calculate short-term sliding-window correlation between two 1D arrays.
184
+ data1: NDArray,
185
+ data2: NDArray,
186
+ sampletime: float,
187
+ windowtime: float,
188
+ samplestep: int = 1,
189
+ detrendorder: int = 0,
190
+ windowfunc: str = "hamming",
191
+ ) -> Tuple[NDArray, NDArray, NDArray]:
192
+ """
193
+ Compute short-term cross-correlation between two 1D signals using sliding windows.
194
+
195
+ This function calculates the Pearson correlation coefficient between two signals
196
+ over short time windows, allowing for the analysis of time-varying correlations.
197
+ The correlation is computed for overlapping windows across the input data,
198
+ with optional detrending and windowing applied to each segment.
188
199
 
189
200
  Parameters
190
201
  ----------
191
- data1
192
- data2
193
- sampletime
194
- windowtime
195
- samplestep
196
- detrendorder
197
- windowfunc
202
+ data1 : NDArray
203
+ First input signal (1D array).
204
+ data2 : NDArray
205
+ Second input signal (1D array). Must have the same length as `data1`.
206
+ sampletime : float
207
+ Time interval between consecutive samples in seconds.
208
+ windowtime : float
209
+ Length of the sliding window in seconds.
210
+ samplestep : int, optional
211
+ Step size (in samples) between consecutive windows. Default is 1.
212
+ detrendorder : int, optional
213
+ Order of detrending to apply before correlation. 0 means no detrending.
214
+ Default is 0.
215
+ windowfunc : str, optional
216
+ Window function to apply to each segment. Default is "hamming".
198
217
 
199
218
  Returns
200
219
  -------
201
- times
202
- corrpertime
203
- ppertime
220
+ times : NDArray
221
+ Array of time values corresponding to the center of each window.
222
+ corrpertime : NDArray
223
+ Array of Pearson correlation coefficients for each window.
224
+ ppertime : NDArray
225
+ Array of p-values associated with the correlation coefficients.
226
+
227
+ Notes
228
+ -----
229
+ The function uses `tide_math.corrnormalize` for normalization and detrending
230
+ of signal segments, and `scipy.stats.pearsonr` for computing the correlation.
231
+
232
+ Examples
233
+ --------
234
+ >>> import numpy as np
235
+ >>> data1 = np.random.randn(1000)
236
+ >>> data2 = np.random.randn(1000)
237
+ >>> times, corr, pvals = shorttermcorr_1D(data1, data2, 0.1, 1.0)
238
+ >>> print(f"Correlation at time {times[0]:.2f}: {corr[0]:.3f}")
204
239
  """
205
240
  windowsize = int(windowtime // sampletime)
206
241
  halfwindow = int((windowsize + 1) // 2)
@@ -218,10 +253,11 @@ def shorttermcorr_1D(
218
253
  detrendorder=detrendorder,
219
254
  windowfunc=windowfunc,
220
255
  )
221
- thepcorr = sp.stats.pearsonr(dataseg1, dataseg2)
256
+ thepearsonresult = sp.stats.pearsonr(dataseg1, dataseg2)
257
+ thepcorrR, thepcorrp = thepearsonresult.statistic, thepearsonresult.pvalue
222
258
  times.append(i * sampletime)
223
- corrpertime.append(thepcorr[0])
224
- ppertime.append(thepcorr[1])
259
+ corrpertime.append(thepcorrR)
260
+ ppertime.append(thepcorrp)
225
261
  return (
226
262
  np.asarray(times, dtype="float64"),
227
263
  np.asarray(corrpertime, dtype="float64"),
@@ -230,42 +266,85 @@ def shorttermcorr_1D(
230
266
 
231
267
 
232
268
  def shorttermcorr_2D(
233
- data1,
234
- data2,
235
- sampletime,
236
- windowtime,
237
- samplestep=1,
238
- laglimits=None,
239
- weighting="None",
240
- zeropadding=0,
241
- windowfunc="None",
242
- detrendorder=0,
243
- compress=False,
244
- displayplots=False,
245
- ):
246
- """Calculate short-term sliding-window correlation between two 2D arrays.
269
+ data1: NDArray,
270
+ data2: NDArray,
271
+ sampletime: float,
272
+ windowtime: float,
273
+ samplestep: int = 1,
274
+ laglimits: Optional[Tuple[float, float]] = None,
275
+ weighting: str = "None",
276
+ zeropadding: int = 0,
277
+ windowfunc: str = "None",
278
+ detrendorder: int = 0,
279
+ compress: bool = False,
280
+ displayplots: bool = False,
281
+ ) -> Tuple[NDArray, NDArray, NDArray, NDArray, NDArray]:
282
+ """
283
+ Compute short-term cross-correlations between two 1D signals over sliding windows.
284
+
285
+ This function computes the cross-correlation between two input signals (`data1` and `data2`)
286
+ using a sliding window approach. For each window, the cross-correlation is computed and
287
+ the peak lag and correlation coefficient are extracted. The function supports detrending,
288
+ windowing, and various correlation weighting schemes.
247
289
 
248
290
  Parameters
249
291
  ----------
250
- data1
251
- data2
252
- sampletime
253
- windowtime
254
- samplestep
255
- laglimits
256
- weighting
257
- zeropadding
258
- windowfunc
259
- detrendorder
260
- displayplots
292
+ data1 : NDArray
293
+ First input signal (1D array).
294
+ data2 : NDArray
295
+ Second input signal (1D array). Must be of the same length as `data1`.
296
+ sampletime : float
297
+ Sampling interval of the input signals in seconds.
298
+ windowtime : float
299
+ Length of the sliding window in seconds.
300
+ samplestep : int, optional
301
+ Step size (in samples) for the sliding window. Default is 1.
302
+ laglimits : Tuple[float, float], optional
303
+ Minimum and maximum lag limits (in seconds) for peak detection.
304
+ If None, defaults to ±windowtime/2.
305
+ weighting : str, optional
306
+ Type of weighting to apply during cross-correlation ('None', 'hamming', etc.).
307
+ Default is 'None'.
308
+ zeropadding : int, optional
309
+ Zero-padding factor for the FFT-based correlation. Default is 0.
310
+ windowfunc : str, optional
311
+ Type of window function to apply ('None', 'hamming', etc.). Default is 'None'.
312
+ detrendorder : int, optional
313
+ Order of detrending to apply before correlation (0 = no detrend, 1 = linear, etc.).
314
+ Default is 0.
315
+ compress : bool, optional
316
+ Whether to compress the correlation result. Default is False.
317
+ displayplots : bool, optional
318
+ Whether to display intermediate plots (e.g., correlation matrix). Default is False.
261
319
 
262
320
  Returns
263
321
  -------
264
- times
265
- xcorrpertime
266
- Rvals
267
- delayvals
268
- valid
322
+ times : NDArray
323
+ Array of time values corresponding to the center of each window.
324
+ xcorrpertime : NDArray
325
+ Array of cross-correlation functions for each window.
326
+ Rvals : NDArray
327
+ Correlation coefficients for each window.
328
+ delayvals : NDArray
329
+ Estimated time delays (lags) for each window.
330
+ valid : NDArray
331
+ Binary array indicating whether the peak detection was successful (1) or failed (0).
332
+
333
+ Notes
334
+ -----
335
+ - The function uses `fastcorrelate` for efficient cross-correlation computation.
336
+ - Peak detection is performed using `tide_fit.findmaxlag_gauss`.
337
+ - If `displayplots` is True, an image of the cross-correlations is shown.
338
+
339
+ Examples
340
+ --------
341
+ >>> import numpy as np
342
+ >>> t = np.linspace(0, 10, 1000)
343
+ >>> signal1 = np.sin(2 * np.pi * 0.5 * t)
344
+ >>> signal2 = np.sin(2 * np.pi * 0.5 * t + 0.1)
345
+ >>> times, xcorrs, Rvals, delays, valid = shorttermcorr_2D(
346
+ ... signal1, signal2, sampletime=0.01, windowtime=1.0
347
+ ... )
269
348
  """
270
349
  windowsize = int(windowtime // sampletime)
271
350
  halfwindow = int((windowsize + 1) // 2)
@@ -356,74 +435,230 @@ def shorttermcorr_2D(
356
435
  )
357
436
 
358
437
 
359
- def calc_MI(x, y, bins=50):
360
- """Calculate mutual information between two arrays.
438
+ def calc_MI(x: NDArray, y: NDArray, bins: int = 50) -> float:
439
+ """
440
+ Calculate mutual information between two arrays.
441
+
442
+ Parameters
443
+ ----------
444
+ x : array-like
445
+ First array of data points
446
+ y : array-like
447
+ Second array of data points
448
+ bins : int, optional
449
+ Number of bins to use for histogram estimation, default is 50
450
+
451
+ Returns
452
+ -------
453
+ float
454
+ Mutual information between x and y
361
455
 
362
456
  Notes
363
457
  -----
364
- From https://stackoverflow.com/questions/20491028/
365
- optimal-way-to-compute-pairwise-mutual-information-using-numpy/
458
+ This implementation uses 2D histogram estimation followed by mutual information
459
+ calculation. The method is based on the approach from:
460
+ https://stackoverflow.com/questions/20491028/optimal-way-to-compute-pairwise-mutual-information-using-numpy/
366
461
  20505476#20505476
462
+
463
+ Examples
464
+ --------
465
+ >>> import numpy as np
466
+ >>> x = np.random.randn(1000)
467
+ >>> y = x + np.random.randn(1000) * 0.5
468
+ >>> mi = calc_MI(x, y)
469
+ >>> print(f"Mutual information: {mi:.3f}")
367
470
  """
368
471
  c_xy = np.histogram2d(x, y, bins)[0]
369
472
  mi = mutual_info_score(None, None, contingency=c_xy)
370
473
  return mi
371
474
 
372
475
 
476
+ # @conditionaljit()
477
+ def mutual_info_2d_fast(
478
+ x: NDArray[np.floating[Any]],
479
+ y: NDArray[np.floating[Any]],
480
+ bins: Tuple[NDArray, NDArray],
481
+ sigma: float = 1,
482
+ normalized: bool = True,
483
+ EPS: float = 1.0e-6,
484
+ debug: bool = False,
485
+ ) -> float:
486
+ """
487
+ Compute (normalized) mutual information between two 1D variates from a joint histogram.
488
+
489
+ Parameters
490
+ ----------
491
+ x : 1D NDArray[np.floating[Any]]
492
+ First variable.
493
+ y : 1D NDArray[np.floating[Any]]
494
+ Second variable.
495
+ bins : tuple of NDArray
496
+ Bin edges for the histogram. The first element corresponds to `x` and the second to `y`.
497
+ sigma : float, optional
498
+ Sigma for Gaussian smoothing of the joint histogram. Default is 1.
499
+ normalized : bool, optional
500
+ If True, compute normalized mutual information as defined in [1]_. Default is True.
501
+ EPS : float, optional
502
+ Small constant to avoid numerical errors in logarithms. Default is 1e-6.
503
+ debug : bool, optional
504
+ If True, print intermediate values for debugging. Default is False.
505
+
506
+ Returns
507
+ -------
508
+ float
509
+ The computed mutual information (or normalized mutual information if `normalized=True`).
510
+
511
+ Notes
512
+ -----
513
+ This function computes mutual information using a 2D histogram and Gaussian smoothing.
514
+ The normalization follows the approach described in [1]_.
515
+
516
+ References
517
+ ----------
518
+ .. [1] Colin Studholme, David John Hawkes, Derek L.G. Hill (1998).
519
+ "Normalized entropy measure for multimodality image alignment".
520
+ in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
521
+
522
+ Examples
523
+ --------
524
+ >>> import numpy as np
525
+ >>> x = np.random.randn(1000)
526
+ >>> y = np.random.randn(1000)
527
+ >>> bins = (np.linspace(-3, 3, 64), np.linspace(-3, 3, 64))
528
+ >>> mi = mutual_info_2d_fast(x, y, bins)
529
+ >>> print(mi)
530
+ """
531
+ xstart = bins[0][0]
532
+ xend = bins[0][-1]
533
+ ystart = bins[1][0]
534
+ yend = bins[1][-1]
535
+ numxbins = int(len(bins[0]) - 1)
536
+ numybins = int(len(bins[1]) - 1)
537
+ cuts = (x >= xstart) & (x < xend) & (y >= ystart) & (y < yend)
538
+ c = ((x[cuts] - xstart) / (xend - xstart) * numxbins).astype(np.int_)
539
+ c += ((y[cuts] - ystart) / (yend - ystart) * numybins).astype(np.int_) * numxbins
540
+ jh = np.bincount(c, minlength=numxbins * numybins).reshape(numxbins, numybins)
541
+
542
+ return proc_MI_histogram(jh, sigma=sigma, normalized=normalized, EPS=EPS, debug=debug)
543
+
544
+
373
545
  # @conditionaljit()
374
546
  def mutual_info_2d(
375
- x, y, sigma=1, bins=(256, 256), fast=False, normalized=True, EPS=1.0e-6, debug=False
376
- ):
377
- """Compute (normalized) mutual information between two 1D variate from a joint histogram.
547
+ x: NDArray[np.floating[Any]],
548
+ y: NDArray[np.floating[Any]],
549
+ bins: Tuple[int, int],
550
+ sigma: float = 1,
551
+ normalized: bool = True,
552
+ EPS: float = 1.0e-6,
553
+ debug: bool = False,
554
+ ) -> float:
555
+ """
556
+ Compute (normalized) mutual information between two 1D variates from a joint histogram.
378
557
 
379
558
  Parameters
380
559
  ----------
381
- x : 1D array
382
- first variable
383
- y : 1D array
384
- second variable
560
+ x : 1D NDArray[np.floating[Any]]
561
+ First variable.
562
+ y : 1D NDArray[np.floating[Any]]
563
+ Second variable.
564
+ bins : tuple of int
565
+ Number of bins for the histogram. The first element is the number of bins for `x`
566
+ and the second for `y`.
385
567
  sigma : float, optional
386
- Sigma for Gaussian smoothing of the joint histogram.
387
- Default = 1.
388
- bins : tuple, optional
389
- fast : bool, optional
390
- normalized : bool
391
- If True, this will calculate the normalized mutual information from [1]_.
392
- Default = False.
568
+ Sigma for Gaussian smoothing of the joint histogram. Default is 1.
569
+ normalized : bool, optional
570
+ If True, compute normalized mutual information as defined in [1]_. Default is True.
393
571
  EPS : float, optional
394
- Default = 1.0e-6.
572
+ Small constant to avoid numerical errors in logarithms. Default is 1e-6.
573
+ debug : bool, optional
574
+ If True, print intermediate values for debugging. Default is False.
395
575
 
396
576
  Returns
397
577
  -------
398
- nmi: float
399
- the computed similarity measure
578
+ float
579
+ The computed mutual information (or normalized mutual information if `normalized=True`).
400
580
 
401
581
  Notes
402
582
  -----
403
- From Ionnis Pappas
404
- BBF added the precaching (fast) option
583
+ This function computes mutual information using a 2D histogram and Gaussian smoothing.
584
+ The normalization follows the approach described in [1]_.
405
585
 
406
586
  References
407
587
  ----------
408
588
  .. [1] Colin Studholme, David John Hawkes, Derek L.G. Hill (1998).
409
589
  "Normalized entropy measure for multimodality image alignment".
410
590
  in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
591
+
592
+ Examples
593
+ --------
594
+ >>> import numpy as np
595
+ >>> x = np.random.randn(1000)
596
+ >>> y = np.random.randn(1000)
597
+ >>> mi = mutual_info_2d(x, y)
598
+ >>> print(mi)
599
+ """
600
+ jh, xbins, ybins = np.histogram2d(x, y, bins=bins)
601
+ if debug:
602
+ print(f"{xbins} {ybins}")
603
+
604
+ return proc_MI_histogram(jh, sigma=sigma, normalized=normalized, EPS=EPS, debug=debug)
605
+
606
+
607
+ def proc_MI_histogram(
608
+ jh: NDArray[np.floating[Any]],
609
+ sigma: float = 1,
610
+ normalized: bool = True,
611
+ EPS: float = 1.0e-6,
612
+ debug: bool = False,
613
+ ) -> float:
614
+ """
615
+ Compute the mutual information (MI) between two variables from a joint histogram.
616
+
617
+ This function calculates mutual information using the joint histogram of two variables,
618
+ applying Gaussian smoothing and computing entropy-based MI. It supports both normalized
619
+ and unnormalized versions of the mutual information.
620
+
621
+ Parameters
622
+ ----------
623
+ jh : ndarray of shape (m, n)
624
+ Joint histogram of two variables. Should be a 2D array of floating point values.
625
+ sigma : float, optional
626
+ Standard deviation for Gaussian smoothing of the joint histogram. Default is 1.0.
627
+ normalized : bool, optional
628
+ If True, returns normalized mutual information. If False, returns unnormalized
629
+ mutual information. Default is True.
630
+ EPS : float, optional
631
+ Small constant added to the histogram to avoid numerical issues in log computation.
632
+ Default is 1e-6.
633
+ debug : bool, optional
634
+ If True, prints intermediate values for debugging purposes. Default is False.
635
+
636
+ Returns
637
+ -------
638
+ float
639
+ The computed mutual information (MI) between the two variables. The value is
640
+ positive and indicates the amount of information shared between the variables.
641
+
642
+ Notes
643
+ -----
644
+ The function applies Gaussian smoothing to the joint histogram before computing
645
+ marginal and joint entropies. The mutual information is computed as:
646
+
647
+ .. math::
648
+ MI = \\frac{H(X) + H(Y)}{H(X,Y)} - 1
649
+
650
+ where :math:`H(X)`, :math:`H(Y)`, and :math:`H(X,Y)` are the marginal and joint entropies,
651
+ respectively. If `normalized=False`, the unnormalized MI is returned instead.
652
+
653
+ Examples
654
+ --------
655
+ >>> import numpy as np
656
+ >>> from scipy import ndimage
657
+ >>> jh = np.random.rand(10, 10)
658
+ >>> mi = proc_MI_histogram(jh, sigma=0.5, normalized=True)
659
+ >>> print(mi)
660
+ 0.123456789
411
661
  """
412
- if fast:
413
- xstart = bins[0][0]
414
- xend = bins[0][-1]
415
- ystart = bins[1][0]
416
- yend = bins[1][-1]
417
- numxbins = len(bins[0]) - 1
418
- numybins = len(bins[1]) - 1
419
- cuts = (x >= xstart) & (x < xend) & (y >= ystart) & (y < yend)
420
- c = ((x[cuts] - xstart) / (xend - xstart) * numxbins).astype(np.int_)
421
- c += ((y[cuts] - ystart) / (yend - ystart) * numybins).astype(np.int_) * numxbins
422
- jh = np.bincount(c, minlength=numxbins * numybins).reshape(numxbins, numybins)
423
- else:
424
- jh, xbins, ybins = np.histogram2d(x, y, bins=bins)
425
- if debug:
426
- print(f"{xbins} {ybins}")
427
662
 
428
663
  # smooth the jh with a gaussian filter of given sigma
429
664
  sp.ndimage.gaussian_filter(jh, sigma=sigma, mode="constant", output=jh)
@@ -458,65 +693,96 @@ def mutual_info_2d(
458
693
 
459
694
  # @conditionaljit
460
695
  def cross_mutual_info(
461
- x,
462
- y,
463
- returnaxis=False,
464
- negsteps=-1,
465
- possteps=-1,
466
- locs=None,
467
- Fs=1.0,
468
- norm=True,
469
- madnorm=False,
470
- windowfunc="None",
471
- bins=-1,
472
- prebin=True,
473
- sigma=0.25,
474
- fast=True,
475
- ):
476
- """Calculate cross-mutual information between two 1D arrays.
696
+ x: NDArray[np.floating[Any]],
697
+ y: NDArray[np.floating[Any]],
698
+ returnaxis: bool = False,
699
+ negsteps: int = -1,
700
+ possteps: int = -1,
701
+ locs: Optional[NDArray] = None,
702
+ Fs: float = 1.0,
703
+ norm: bool = True,
704
+ madnorm: bool = False,
705
+ windowfunc: str = "None",
706
+ bins: int = -1,
707
+ prebin: bool = True,
708
+ sigma: float = 0.25,
709
+ fast: bool = True,
710
+ ) -> Union[NDArray, Tuple[NDArray, NDArray, int]]:
711
+ """
712
+ Calculate cross-mutual information between two 1D arrays.
713
+
714
+ This function computes the cross-mutual information (MI) between two signals
715
+ `x` and `y` at various time lags or specified offsets. It supports normalization,
716
+ windowing, and histogram smoothing for robust estimation.
717
+
477
718
  Parameters
478
719
  ----------
479
- x : 1D array
480
- first variable
481
- y : 1D array
482
- second variable. The length of y must by >= the length of x
483
- returnaxis : bool
484
- set to True to return the time axis
485
- negsteps: int
486
- possteps: int
487
- locs : list
488
- a set of offsets at which to calculate the cross mutual information
489
- Fs=1.0,
490
- norm : bool
491
- calculate normalized MI at each offset
492
- madnorm : bool
493
- set to True to normalize cross MI waveform by it's median average deviate
494
- windowfunc : str
495
- name of the window function to apply to input vectors prior to MI calculation
496
- bins : int
497
- number of bins in each dimension of the 2D histogram. Set to -1 to set automatically
498
- prebin : bool
499
- set to true to cache 2D histogram for all offsets
500
- sigma : float
501
- histogram smoothing kernel
502
- fast: bool
503
- apply speed optimizations
720
+ x : NDArray[np.floating[Any]]
721
+ First variable (signal).
722
+ y : NDArray[np.floating[Any]]
723
+ Second variable (signal). Must have length >= length of `x`.
724
+ returnaxis : bool, optional
725
+ If True, return the time axis along with the MI values. Default is False.
726
+ negsteps : int, optional
727
+ Number of negative time steps to compute MI for. If -1, uses default based on signal length.
728
+ Default is -1.
729
+ possteps : int, optional
730
+ Number of positive time steps to compute MI for. If -1, uses default based on signal length.
731
+ Default is -1.
732
+ locs : ndarray of int, optional
733
+ Specific time offsets at which to compute MI. If None, uses `negsteps` and `possteps`.
734
+ Default is None.
735
+ Fs : float, optional
736
+ Sampling frequency. Used when `returnaxis` is True. Default is 1.0.
737
+ norm : bool, optional
738
+ If True, normalize the MI values. Default is True.
739
+ madnorm : bool, optional
740
+ If True, normalize the MI waveform by its median absolute deviation (MAD).
741
+ Default is False.
742
+ windowfunc : str, optional
743
+ Name of the window function to apply to input signals before MI calculation.
744
+ Default is "None".
745
+ bins : int, optional
746
+ Number of bins for the 2D histogram. If -1, automatically determined.
747
+ Default is -1.
748
+ prebin : bool, optional
749
+ If True, precompute and cache the 2D histogram for all offsets.
750
+ Default is True.
751
+ sigma : float, optional
752
+ Standard deviation of the Gaussian smoothing kernel applied to the histogram.
753
+ Default is 0.25.
754
+ fast : bool, optional
755
+ If True, apply speed optimizations. Default is True.
504
756
 
505
757
  Returns
506
758
  -------
507
- if returnaxis is True:
508
- thexmi_x : 1D array
509
- the set of offsets at which cross mutual information is calculated
510
- thexmi_y : 1D array
511
- the set of cross mutual information values
512
- len(thexmi_x): int
513
- the number of cross mutual information values returned
514
- else:
515
- thexmi_y : 1D array
516
- the set of cross mutual information values
759
+ ndarray or tuple of ndarray
760
+ If `returnaxis` is False:
761
+ The set of cross-mutual information values.
762
+ If `returnaxis` is True:
763
+ Tuple of (time_axis, mi_values, num_values), where:
764
+ - time_axis : ndarray of float
765
+ Time axis corresponding to the MI values.
766
+ - mi_values : ndarray of float
767
+ Cross-mutual information values.
768
+ - num_values : int
769
+ Number of MI values returned.
517
770
 
771
+ Notes
772
+ -----
773
+ - The function normalizes input signals using detrending and optional windowing.
774
+ - Cross-mutual information is computed using 2D histogram estimation and
775
+ mutual information calculation.
776
+ - If `prebin` is True, the 2D histogram is precomputed for efficiency.
777
+
778
+ Examples
779
+ --------
780
+ >>> import numpy as np
781
+ >>> x = np.random.randn(100)
782
+ >>> y = np.random.randn(100)
783
+ >>> mi = cross_mutual_info(x, y)
784
+ >>> mi_axis, mi_vals, num = cross_mutual_info(x, y, returnaxis=True, Fs=10)
518
785
  """
519
-
520
786
  normx = tide_math.corrnormalize(x, detrendorder=1, windowfunc=windowfunc)
521
787
  normy = tide_math.corrnormalize(y, detrendorder=1, windowfunc=windowfunc)
522
788
 
@@ -555,32 +821,56 @@ def cross_mutual_info(
555
821
  else:
556
822
  destloc += 1
557
823
  if i < 0:
558
- thexmi_y[destloc] = mutual_info_2d(
559
- normx[: i + len(normy)],
560
- normy[-i:],
561
- bins=bins2d,
562
- normalized=norm,
563
- fast=fast,
564
- sigma=sigma,
565
- )
824
+ if fast:
825
+ thexmi_y[destloc] = mutual_info_2d_fast(
826
+ normx[: i + len(normy)],
827
+ normy[-i:],
828
+ bins2d,
829
+ normalized=norm,
830
+ sigma=sigma,
831
+ )
832
+ else:
833
+ thexmi_y[destloc] = mutual_info_2d(
834
+ normx[: i + len(normy)],
835
+ normy[-i:],
836
+ bins2d,
837
+ normalized=norm,
838
+ sigma=sigma,
839
+ )
566
840
  elif i == 0:
567
- thexmi_y[destloc] = mutual_info_2d(
568
- normx,
569
- normy,
570
- bins=bins2d,
571
- normalized=norm,
572
- fast=fast,
573
- sigma=sigma,
574
- )
841
+ if fast:
842
+ thexmi_y[destloc] = mutual_info_2d_fast(
843
+ normx,
844
+ normy,
845
+ bins2d,
846
+ normalized=norm,
847
+ sigma=sigma,
848
+ )
849
+ else:
850
+ thexmi_y[destloc] = mutual_info_2d(
851
+ normx,
852
+ normy,
853
+ bins2d,
854
+ normalized=norm,
855
+ sigma=sigma,
856
+ )
575
857
  else:
576
- thexmi_y[destloc] = mutual_info_2d(
577
- normx[i:],
578
- normy[: len(normy) - i],
579
- bins=bins2d,
580
- normalized=norm,
581
- fast=fast,
582
- sigma=sigma,
583
- )
858
+ if fast:
859
+ thexmi_y[destloc] = mutual_info_2d_fast(
860
+ normx[i:],
861
+ normy[: len(normy) - i],
862
+ bins2d,
863
+ normalized=norm,
864
+ sigma=sigma,
865
+ )
866
+ else:
867
+ thexmi_y[destloc] = mutual_info_2d(
868
+ normx[i:],
869
+ normy[: len(normy) - i],
870
+ bins2d,
871
+ normalized=norm,
872
+ sigma=sigma,
873
+ )
584
874
 
585
875
  if madnorm:
586
876
  thexmi_y = tide_math.madnormalize(thexmi_y)
@@ -599,77 +889,94 @@ def cross_mutual_info(
599
889
  return thexmi_y
600
890
 
601
891
 
602
- def mutual_info_to_r(themi, d=1):
603
- """Convert mutual information to Pearson product-moment correlation."""
604
- return np.power(1.0 - np.exp(-2.0 * themi / d), -0.5)
892
+ def mutual_info_to_r(themi: float, d: int = 1) -> float:
893
+ """
894
+ Convert mutual information to Pearson product-moment correlation.
605
895
 
896
+ This function transforms mutual information values into Pearson correlation coefficients
897
+ using the relationship derived from the assumption of joint Gaussian distributions.
606
898
 
607
- def dtw_distance(s1, s2):
608
- # Dynamic time warping function written by GPT-4
609
- # Get the lengths of the two input sequences
610
- n, m = len(s1), len(s2)
899
+ Parameters
900
+ ----------
901
+ themi : float
902
+ Mutual information value (in nats) to be converted.
903
+ d : int, default=1
904
+ Dimensionality of the random variables. For single-dimensional variables, d=1.
905
+ For multi-dimensional variables, d represents the number of dimensions.
611
906
 
612
- # Initialize a (n+1) x (m+1) matrix with zeros
613
- DTW = np.zeros((n + 1, m + 1))
907
+ Returns
908
+ -------
909
+ float
910
+ Pearson product-moment correlation coefficient corresponding to the input
911
+ mutual information value. The result is in the range [0, 1].
614
912
 
615
- # Set the first row and first column of the matrix to infinity, since
616
- # the first element of each sequence cannot be aligned with an empty sequence
617
- DTW[1:, 0] = np.inf
618
- DTW[0, 1:] = np.inf
913
+ Notes
914
+ -----
915
+ The transformation is based on the formula:
916
+ r = (1 - exp(-2*MI/d))^(-1/2)
619
917
 
620
- # Compute the DTW distance by iteratively filling in the matrix
621
- for i in range(1, n + 1):
622
- for j in range(1, m + 1):
623
- # Compute the cost of aligning the i-th element of s1 with the j-th element of s2
624
- cost = abs(s1[i - 1] - s2[j - 1])
918
+ This approximation is valid under the assumption that the variables follow
919
+ a joint Gaussian distribution. For non-Gaussian distributions, the relationship
920
+ may not hold exactly.
625
921
 
626
- # Compute the minimum cost of aligning the first i-1 elements of s1 with the first j elements of s2,
627
- # the first i elements of s1 with the first j-1 elements of s2, and the first i-1 elements of s1
628
- # with the first j-1 elements of s2, and add this to the cost of aligning the i-th element of s1
629
- # with the j-th element of s2
630
- DTW[i, j] = cost + np.min([DTW[i - 1, j], DTW[i, j - 1], DTW[i - 1, j - 1]])
922
+ Examples
923
+ --------
924
+ >>> mutual_info_to_r(1.0)
925
+ 0.8416445342422313
631
926
 
632
- # Return the DTW distance between the two sequences, which is the value in the last cell of the matrix
633
- return DTW[n, m]
927
+ >>> mutual_info_to_r(2.0, d=2)
928
+ 0.9640275800758169
929
+ """
930
+ return np.power(1.0 - np.exp(-2.0 * themi / d), -0.5)
634
931
 
635
932
 
636
- def delayedcorr(data1, data2, delayval, timestep):
637
- """Calculate correlation between two 1D arrays, at specific delay.
933
+ def delayedcorr(
934
+ data1: NDArray, data2: NDArray, delayval: float, timestep: float
935
+ ) -> Tuple[float, float]:
936
+ return sp.stats.pearsonr(
937
+ data1, tide_resample.timeshift(data2, delayval / timestep, 30).statistic
938
+ )
638
939
 
639
- Parameters
640
- ----------
641
- data1
642
- data2
643
- delayval
644
- timestep
645
940
 
646
- Returns
647
- -------
648
- corr
941
+ def cepstraldelay(
942
+ data1: NDArray, data2: NDArray, timestep: float, displayplots: bool = True
943
+ ) -> float:
649
944
  """
650
- return sp.stats.pearsonr(data1, tide_resample.timeshift(data2, delayval / timestep, 30)[0])
945
+ Calculate correlation between two datasets with a time delay applied to the second dataset.
651
946
 
652
-
653
- def cepstraldelay(data1, data2, timestep, displayplots=True):
654
- """
655
- Estimate delay between two signals using Choudhary's cepstral analysis method.
947
+ This function computes the Pearson correlation coefficient between two datasets,
948
+ where the second dataset is time-shifted by a specified delay before correlation
949
+ is calculated. The time shift is applied using the tide_resample.timeshift function.
656
950
 
657
951
  Parameters
658
952
  ----------
659
- data1
660
- data2
661
- timestep
662
- displayplots
953
+ data1 : NDArray
954
+ First dataset for correlation calculation.
955
+ data2 : NDArray
956
+ Second dataset to be time-shifted and correlated with data1.
957
+ delayval : float
958
+ Time delay to apply to data2, specified in the same units as timestep.
959
+ timestep : float
960
+ Time step of the datasets, used to convert delayval to sample units.
663
961
 
664
962
  Returns
665
963
  -------
666
- arr
964
+ Tuple[float, float]
965
+ Pearson correlation coefficient and p-value from the correlation test.
667
966
 
668
- References
669
- ----------
670
- * Choudhary, H., Bahl, R. & Kumar, A.
671
- Inter-sensor Time Delay Estimation using cepstrum of sum and difference signals in
672
- underwater multipath environment. in 1-7 (IEEE, 2015). doi:10.1109/UT.2015.7108308
967
+ Notes
968
+ -----
969
+ The delayval is converted to sample units by dividing by timestep before
970
+ applying the time shift. The tide_resample.timeshift function is used internally
971
+ with a window parameter of 30.
972
+
973
+ Examples
974
+ --------
975
+ >>> import numpy as np
976
+ >>> data1 = np.array([1, 2, 3, 4, 5])
977
+ >>> data2 = np.array([2, 3, 4, 5, 6])
978
+ >>> corr, p_value = delayedcorr(data1, data2, delay=1.0, timestep=0.1)
979
+ >>> print(f"Correlation: {corr:.3f}")
673
980
  """
674
981
  ceps1, _ = tide_math.complex_cepstrum(data1)
675
982
  ceps2, _ = tide_math.complex_cepstrum(data2)
@@ -707,24 +1014,74 @@ def cepstraldelay(data1, data2, timestep, displayplots=True):
707
1014
 
708
1015
 
709
1016
  class AliasedCorrelator:
710
- """An aliased correlator.
1017
+ def __init__(self, hiressignal, hires_Fs, numsteps):
1018
+ """
1019
+ Initialize the object with high-resolution signal parameters.
711
1020
 
712
- Parameters
713
- ----------
714
- hiressignal : 1D array
715
- The unaliased waveform to match
716
- hires_Fs : float
717
- The sample rate of the unaliased waveform
718
- numsteps : int
719
- Number of distinct slice acquisition times within the TR.
720
- """
1021
+ Parameters
1022
+ ----------
1023
+ hiressignal : array-like
1024
+ High-resolution signal data to be processed.
1025
+ hires_Fs : float
1026
+ Sampling frequency of the high-resolution signal in Hz.
1027
+ numsteps : int
1028
+ Number of steps for signal processing.
721
1029
 
722
- def __init__(self, hiressignal, hires_Fs, numsteps):
1030
+ Returns
1031
+ -------
1032
+ None
1033
+ This method initializes the object attributes and does not return any value.
1034
+
1035
+ Notes
1036
+ -----
1037
+ This constructor sets up the basic configuration for high-resolution signal processing
1038
+ by storing the sampling frequency and number of steps, then calls sethiressignal()
1039
+ to process the input signal.
1040
+
1041
+ Examples
1042
+ --------
1043
+ >>> obj = MyClass(hiressignal, hires_Fs=44100, numsteps=100)
1044
+ >>> obj.hires_Fs
1045
+ 44100
1046
+ >>> obj.numsteps
1047
+ 100
1048
+ """
723
1049
  self.hires_Fs = hires_Fs
724
1050
  self.numsteps = numsteps
725
1051
  self.sethiressignal(hiressignal)
726
1052
 
727
1053
  def sethiressignal(self, hiressignal):
1054
+ """
1055
+ Set high resolution signal and compute related parameters.
1056
+
1057
+ This method processes the high resolution signal by normalizing it and computing
1058
+ correlation-related parameters including correlation length and correlation x-axis.
1059
+
1060
+ Parameters
1061
+ ----------
1062
+ hiressignal : array-like
1063
+ High resolution signal data to be processed and normalized.
1064
+
1065
+ Returns
1066
+ -------
1067
+ None
1068
+ This method modifies the instance attributes in-place and does not return a value.
1069
+
1070
+ Notes
1071
+ -----
1072
+ The method performs correlation normalization using `tide_math.corrnormalize` and
1073
+ computes the correlation length as `len(self.hiressignal) * 2 + 1`. The correlation
1074
+ x-axis is computed based on the sampling frequency (`self.hires_Fs`) and the length
1075
+ of the high resolution signal.
1076
+
1077
+ Examples
1078
+ --------
1079
+ >>> obj.sethiressignal(hiressignal_data)
1080
+ >>> print(obj.corrlen)
1081
+ 1001
1082
+ >>> print(obj.corrx.shape)
1083
+ (1001,)
1084
+ """
728
1085
  self.hiressignal = tide_math.corrnormalize(hiressignal)
729
1086
  self.corrlen = len(self.hiressignal) * 2 + 1
730
1087
  self.corrx = (
@@ -733,23 +1090,63 @@ class AliasedCorrelator:
733
1090
  )
734
1091
 
735
1092
  def getxaxis(self):
1093
+ """
1094
+ Return the x-axis correction value.
1095
+
1096
+ This method retrieves the correction value applied to the x-axis.
1097
+
1098
+ Returns
1099
+ -------
1100
+ float or int
1101
+ The correction value for the x-axis stored in `self.corrx`.
1102
+
1103
+ Notes
1104
+ -----
1105
+ The returned value represents the x-axis correction that has been
1106
+ previously computed or set in the object's `corrx` attribute.
1107
+
1108
+ Examples
1109
+ --------
1110
+ >>> obj = MyClass()
1111
+ >>> obj.corrx = 5.0
1112
+ >>> obj.getxaxis()
1113
+ 5.0
1114
+ """
736
1115
  return self.corrx
737
1116
 
738
1117
  def apply(self, loressignal, offset, debug=False):
739
- """Apply correlator to aliased waveform.
1118
+ """
1119
+ Apply correlator to aliased waveform.
1120
+
740
1121
  NB: Assumes the highres frequency is an integral multiple of the lowres frequency
1122
+
741
1123
  Parameters
742
1124
  ----------
743
- loressignal: 1D array
1125
+ loressignal : 1D array
744
1126
  The aliased waveform to match
745
- offset: int
1127
+ offset : int
746
1128
  Integer offset to apply to the upsampled lowressignal (to account for slice time offset)
747
- debug: bool, optional
1129
+ debug : bool, optional
748
1130
  Whether to print diagnostic information
1131
+
749
1132
  Returns
750
1133
  -------
751
- corrfunc: 1D array
1134
+ corrfunc : 1D array
752
1135
  The full correlation function
1136
+
1137
+ Notes
1138
+ -----
1139
+ This function applies a correlator to an aliased waveform by:
1140
+ 1. Creating an upsampled version of the high-resolution signal
1141
+ 2. Inserting the low-resolution signal at the specified offset
1142
+ 3. Computing the cross-correlation between the two signals
1143
+ 4. Normalizing the result by the square root of the number of steps
1144
+
1145
+ Examples
1146
+ --------
1147
+ >>> result = correlator.apply(signal, offset=5, debug=True)
1148
+ >>> print(result.shape)
1149
+ (len(highres_signal),)
753
1150
  """
754
1151
  if debug:
755
1152
  print(offset, self.numsteps)
@@ -762,13 +1159,63 @@ class AliasedCorrelator:
762
1159
 
763
1160
 
764
1161
  def matchsamplerates(
765
- input1,
766
- Fs1,
767
- input2,
768
- Fs2,
769
- method="univariate",
770
- debug=False,
771
- ):
1162
+ input1: NDArray,
1163
+ Fs1: float,
1164
+ input2: NDArray,
1165
+ Fs2: float,
1166
+ method: str = "univariate",
1167
+ debug: bool = False,
1168
+ ) -> Tuple[NDArray, NDArray, float]:
1169
+ """
1170
+ Match sampling rates of two input arrays by upsampling the lower sampling rate signal.
1171
+
1172
+ This function takes two input arrays with potentially different sampling rates and
1173
+ ensures they have the same sampling rate by upsampling the signal with the lower
1174
+ sampling rate to match the higher one. The function preserves the original data
1175
+ while adjusting the sampling rate for compatibility.
1176
+
1177
+ Parameters
1178
+ ----------
1179
+ input1 : NDArray
1180
+ First input array to be processed.
1181
+ Fs1 : float
1182
+ Sampling frequency of the first input array (Hz).
1183
+ input2 : NDArray
1184
+ Second input array to be processed.
1185
+ Fs2 : float
1186
+ Sampling frequency of the second input array (Hz).
1187
+ method : str, optional
1188
+ Resampling method to use, by default "univariate".
1189
+ See `tide_resample.upsample` for available methods.
1190
+ debug : bool, optional
1191
+ Enable debug output, by default False.
1192
+
1193
+ Returns
1194
+ -------
1195
+ Tuple[NDArray, NDArray, float]
1196
+ Tuple containing:
1197
+ - matchedinput1: First input array upsampled to match the sampling rate
1198
+ - matchedinput2: Second input array upsampled to match the sampling rate
1199
+ - corrFs: The common sampling frequency used for both outputs
1200
+
1201
+ Notes
1202
+ -----
1203
+ - If sampling rates are equal, no upsampling is performed
1204
+ - The function always upsamples to the higher sampling rate
1205
+ - The upsampling is performed using the `tide_resample.upsample` function
1206
+ - Both output arrays will have the same length and sampling rate
1207
+
1208
+ Examples
1209
+ --------
1210
+ >>> import numpy as np
1211
+ >>> input1 = np.array([1, 2, 3, 4])
1212
+ >>> input2 = np.array([5, 6, 7])
1213
+ >>> Fs1 = 10.0
1214
+ >>> Fs2 = 5.0
1215
+ >>> matched1, matched2, common_fs = matchsamplerates(input1, Fs1, input2, Fs2)
1216
+ >>> print(common_fs)
1217
+ 10.0
1218
+ """
772
1219
  if Fs1 > Fs2:
773
1220
  corrFs = Fs1
774
1221
  matchedinput1 = input1
@@ -785,16 +1232,74 @@ def matchsamplerates(
785
1232
 
786
1233
 
787
1234
  def arbcorr(
788
- input1,
789
- Fs1,
790
- input2,
791
- Fs2,
792
- start1=0.0,
793
- start2=0.0,
794
- windowfunc="hamming",
795
- method="univariate",
796
- debug=False,
797
- ):
1235
+ input1: NDArray,
1236
+ Fs1: float,
1237
+ input2: NDArray,
1238
+ Fs2: float,
1239
+ start1: float = 0.0,
1240
+ start2: float = 0.0,
1241
+ windowfunc: str = "hamming",
1242
+ method: str = "univariate",
1243
+ debug: bool = False,
1244
+ ) -> Tuple[NDArray, NDArray, float, int]:
1245
+ """
1246
+ Compute the cross-correlation between two signals with arbitrary sampling rates.
1247
+
1248
+ This function performs cross-correlation between two input signals after
1249
+ matching their sampling rates. It applies normalization and uses FFT-based
1250
+ convolution for efficient computation. The result includes the time lag axis,
1251
+ cross-correlation values, the matched sampling frequency, and the index of
1252
+ the zero-lag point.
1253
+
1254
+ Parameters
1255
+ ----------
1256
+ input1 : NDArray
1257
+ First input signal array.
1258
+ Fs1 : float
1259
+ Sampling frequency of the first signal (Hz).
1260
+ input2 : NDArray
1261
+ Second input signal array.
1262
+ Fs2 : float
1263
+ Sampling frequency of the second signal (Hz).
1264
+ start1 : float, optional
1265
+ Start time of the first signal (default is 0.0).
1266
+ start2 : float, optional
1267
+ Start time of the second signal (default is 0.0).
1268
+ windowfunc : str, optional
1269
+ Window function used for normalization (default is "hamming").
1270
+ method : str, optional
1271
+ Method used for matching sampling rates (default is "univariate").
1272
+ debug : bool, optional
1273
+ If True, enables debug logging (default is False).
1274
+
1275
+ Returns
1276
+ -------
1277
+ tuple
1278
+ A tuple containing:
1279
+ - thexcorr_x : NDArray
1280
+ Time lag axis for the cross-correlation (seconds).
1281
+ - thexcorr_y : NDArray
1282
+ Cross-correlation values.
1283
+ - corrFs : float
1284
+ Matched sampling frequency used for the computation (Hz).
1285
+ - zeroloc : int
1286
+ Index corresponding to the zero-lag point in the cross-correlation.
1287
+
1288
+ Notes
1289
+ -----
1290
+ - The function upsamples the signals to the higher of the two sampling rates.
1291
+ - Normalization is applied using a detrend order of 1 and the specified window function.
1292
+ - The cross-correlation is computed using FFT convolution for efficiency.
1293
+ - The zero-lag point is determined as the index of the minimum absolute value in the time axis.
1294
+
1295
+ Examples
1296
+ --------
1297
+ >>> import numpy as np
1298
+ >>> signal1 = np.random.randn(1000)
1299
+ >>> signal2 = np.random.randn(1000)
1300
+ >>> lags, corr_vals, fs, zero_idx = arbcorr(signal1, 10.0, signal2, 10.0)
1301
+ >>> print(f"Zero-lag index: {zero_idx}")
1302
+ """
798
1303
  # upsample to the higher frequency of the two
799
1304
  matchedinput1, matchedinput2, corrFs = matchsamplerates(
800
1305
  input1,
@@ -822,13 +1327,66 @@ def arbcorr(
822
1327
 
823
1328
 
824
1329
  def faststcorrelate(
825
- input1, input2, windowtype="hann", nperseg=32, weighting="None", displayplots=False
826
- ):
827
- """Perform correlation between short-time Fourier transformed arrays."""
1330
+ input1: NDArray,
1331
+ input2: NDArray,
1332
+ windowtype: str = "hann",
1333
+ nperseg: int = 32,
1334
+ weighting: str = "None",
1335
+ displayplots: bool = False,
1336
+ ) -> Tuple[NDArray, NDArray, NDArray]:
1337
+ """
1338
+ Perform correlation between short-time Fourier transformed arrays.
1339
+
1340
+ This function computes the short-time cross-correlation between two input signals
1341
+ using their short-time Fourier transforms (STFTs). It applies a windowing function
1342
+ to each signal, computes the STFT, and then performs correlation in the frequency
1343
+ domain before inverse transforming back to the time domain. The result is normalized
1344
+ by the auto-correlation of each signal.
1345
+
1346
+ Parameters
1347
+ ----------
1348
+ input1 : ndarray
1349
+ First input signal array.
1350
+ input2 : ndarray
1351
+ Second input signal array.
1352
+ windowtype : str, optional
1353
+ Type of window to apply. Default is 'hann'.
1354
+ nperseg : int, optional
1355
+ Length of each segment for STFT. Default is 32.
1356
+ weighting : str, optional
1357
+ Weighting method for the STFT. Default is 'None'.
1358
+ displayplots : bool, optional
1359
+ If True, display plots (not implemented in current version). Default is False.
1360
+
1361
+ Returns
1362
+ -------
1363
+ corrtimes : ndarray
1364
+ Time shifts corresponding to the correlation results.
1365
+ times : ndarray
1366
+ Time indices of the STFT.
1367
+ stcorr : ndarray
1368
+ Short-time cross-correlation values.
1369
+
1370
+ Notes
1371
+ -----
1372
+ The function uses `scipy.signal.stft` to compute the short-time Fourier transform
1373
+ of both input signals. The correlation is computed in the frequency domain and
1374
+ normalized by the square root of the auto-correlation of each signal.
1375
+
1376
+ Examples
1377
+ --------
1378
+ >>> import numpy as np
1379
+ >>> from scipy import signal
1380
+ >>> t = np.linspace(0, 1, 100)
1381
+ >>> x1 = np.sin(2 * np.pi * 5 * t)
1382
+ >>> x2 = np.sin(2 * np.pi * 5 * t + 0.1)
1383
+ >>> corrtimes, times, corr = faststcorrelate(x1, x2)
1384
+ >>> print(corr.shape)
1385
+ (32, 100)
1386
+ """
828
1387
  nfft = nperseg
829
1388
  noverlap = nperseg - 1
830
1389
  onesided = False
831
- boundary = "even"
832
1390
  freqs, times, thestft1 = signal.stft(
833
1391
  input1,
834
1392
  fs=1.0,
@@ -838,7 +1396,7 @@ def faststcorrelate(
838
1396
  nfft=nfft,
839
1397
  detrend="linear",
840
1398
  return_onesided=onesided,
841
- boundary=boundary,
1399
+ boundary="even",
842
1400
  padded=True,
843
1401
  axis=-1,
844
1402
  )
@@ -852,7 +1410,7 @@ def faststcorrelate(
852
1410
  nfft=nfft,
853
1411
  detrend="linear",
854
1412
  return_onesided=onesided,
855
- boundary=boundary,
1413
+ boundary="even",
856
1414
  padded=True,
857
1415
  axis=-1,
858
1416
  )
@@ -878,7 +1436,40 @@ def faststcorrelate(
878
1436
  return corrtimes, times, stcorr
879
1437
 
880
1438
 
881
- def primefacs(thelen):
1439
+ def primefacs(thelen: int) -> list:
1440
+ """
1441
+ Compute the prime factorization of a given integer.
1442
+
1443
+ Parameters
1444
+ ----------
1445
+ thelen : int
1446
+ The positive integer to factorize. Must be greater than 0.
1447
+
1448
+ Returns
1449
+ -------
1450
+ list
1451
+ A list of prime factors of `thelen`, sorted in ascending order.
1452
+ Each factor appears as many times as its multiplicity in the
1453
+ prime factorization.
1454
+
1455
+ Notes
1456
+ -----
1457
+ This function implements trial division algorithm to find prime factors.
1458
+ The algorithm starts with the smallest prime (2) and continues with
1459
+ increasing integers until the square root of the remaining number.
1460
+ The final remaining number (if greater than 1) is also a prime factor.
1461
+
1462
+ Examples
1463
+ --------
1464
+ >>> primefacs(12)
1465
+ [2, 2, 3]
1466
+
1467
+ >>> primefacs(17)
1468
+ [17]
1469
+
1470
+ >>> primefacs(100)
1471
+ [2, 2, 5, 5]
1472
+ """
882
1473
  i = 2
883
1474
  factors = []
884
1475
  while i * i <= thelen:
@@ -891,40 +1482,99 @@ def primefacs(thelen):
891
1482
  return factors
892
1483
 
893
1484
 
894
- def optfftlen(thelen):
1485
+ def optfftlen(thelen: int) -> int:
1486
+ """
1487
+ Calculate optimal FFT length for given input length.
1488
+
1489
+ This function currently returns the input length as-is, but is designed
1490
+ to be extended for optimal FFT length calculation based on hardware
1491
+ constraints or performance considerations.
1492
+
1493
+ Parameters
1494
+ ----------
1495
+ thelen : int
1496
+ The input length for which to calculate optimal FFT length.
1497
+ Must be a positive integer.
1498
+
1499
+ Returns
1500
+ -------
1501
+ int
1502
+ The optimal FFT length. For the current implementation, this
1503
+ simply returns the input `thelen` value.
1504
+
1505
+ Notes
1506
+ -----
1507
+ In a more complete implementation, this function would calculate
1508
+ the optimal FFT length by finding the smallest number >= thelen
1509
+ that has only small prime factors (2, 3, 5, 7) for optimal
1510
+ performance on most FFT implementations.
1511
+
1512
+ Examples
1513
+ --------
1514
+ >>> optfftlen(1024)
1515
+ 1024
1516
+ >>> optfftlen(1000)
1517
+ 1000
1518
+ """
895
1519
  return thelen
896
1520
 
897
1521
 
898
1522
  def fastcorrelate(
899
- input1,
900
- input2,
901
- usefft=True,
902
- zeropadding=0,
903
- weighting="None",
904
- compress=False,
905
- displayplots=False,
906
- debug=False,
907
- ):
908
- """Perform a fast correlation between two arrays.
1523
+ input1: NDArray,
1524
+ input2: NDArray,
1525
+ usefft: bool = True,
1526
+ zeropadding: int = 0,
1527
+ weighting: str = "None",
1528
+ compress: bool = False,
1529
+ displayplots: bool = False,
1530
+ debug: bool = False,
1531
+ ) -> NDArray:
1532
+ """
1533
+ Perform a fast correlation between two arrays.
1534
+
1535
+ This function computes the cross-correlation of two input arrays, with options
1536
+ for using FFT-based convolution or direct correlation, as well as padding and
1537
+ weighting schemes.
909
1538
 
910
1539
  Parameters
911
1540
  ----------
912
- input1
913
- input2
914
- usefft
915
- zeropadding
916
- weighting
917
- compress
918
- displayplots
919
- debug
1541
+ input1 : ndarray
1542
+ First input array to correlate.
1543
+ input2 : ndarray
1544
+ Second input array to correlate.
1545
+ usefft : bool, optional
1546
+ If True, use FFT-based convolution for faster computation. Default is True.
1547
+ zeropadding : int, optional
1548
+ Zero-padding length. If 0, no padding is applied. If negative, automatic
1549
+ padding is applied. If positive, explicit padding is applied. Default is 0.
1550
+ weighting : str, optional
1551
+ Type of weighting to apply. If "None", no weighting is applied. Default is "None".
1552
+ compress : bool, optional
1553
+ If True and `weighting` is not "None", compress the result. Default is False.
1554
+ displayplots : bool, optional
1555
+ If True, display plots of padded inputs and correlation result. Default is False.
1556
+ debug : bool, optional
1557
+ If True, enable debug output. Default is False.
920
1558
 
921
1559
  Returns
922
1560
  -------
923
- corr
1561
+ ndarray
1562
+ The cross-correlation of `input1` and `input2`. The length of the output is
1563
+ `len(input1) + len(input2) - 1`.
924
1564
 
925
1565
  Notes
926
1566
  -----
927
- From http://stackoverflow.com/questions/12323959/fast-cross-correlation-method-in-python.
1567
+ This implementation is based on the method described at:
1568
+ http://stackoverflow.com/questions/12323959/fast-cross-correlation-method-in-python
1569
+
1570
+ Examples
1571
+ --------
1572
+ >>> import numpy as np
1573
+ >>> a = np.array([1, 2, 3])
1574
+ >>> b = np.array([0, 1, 0])
1575
+ >>> result = fastcorrelate(a, b)
1576
+ >>> print(result)
1577
+ [0. 1. 2. 3. 0.]
928
1578
  """
929
1579
  len1 = len(input1)
930
1580
  len2 = len(input2)
@@ -995,17 +1645,36 @@ def fastcorrelate(
995
1645
  return np.correlate(paddedinput1, paddedinput2, mode="full")
996
1646
 
997
1647
 
998
- def _centered(arr, newsize):
999
- """Return the center newsize portion of the array.
1648
+ def _centered(arr: NDArray, newsize: Union[int, NDArray]) -> NDArray:
1649
+ """
1650
+ Extract a centered subset of an array.
1000
1651
 
1001
1652
  Parameters
1002
1653
  ----------
1003
- arr
1004
- newsize
1654
+ arr : array_like
1655
+ Input array from which to extract the centered subset.
1656
+ newsize : int or array_like
1657
+ The size of the output array. If int, the same size is used for all dimensions.
1658
+ If array_like, specifies the size for each dimension.
1005
1659
 
1006
1660
  Returns
1007
1661
  -------
1008
- arr
1662
+ ndarray
1663
+ Centered subset of the input array with the specified size.
1664
+
1665
+ Notes
1666
+ -----
1667
+ The function extracts a subset from the center of the input array. If the requested
1668
+ size is larger than the current array size in any dimension, the result will be
1669
+ padded with zeros (or the array will be truncated from the center).
1670
+
1671
+ Examples
1672
+ --------
1673
+ >>> import numpy as np
1674
+ >>> arr = np.arange(24).reshape(4, 6)
1675
+ >>> _centered(arr, (2, 3))
1676
+ array([[ 7, 8, 9],
1677
+ [13, 14, 15]])
1009
1678
  """
1010
1679
  newsize = np.asarray(newsize)
1011
1680
  currsize = np.array(arr.shape)
@@ -1015,22 +1684,36 @@ def _centered(arr, newsize):
1015
1684
  return arr[tuple(myslice)]
1016
1685
 
1017
1686
 
1018
- def _check_valid_mode_shapes(shape1, shape2):
1019
- """Check that two shapes are 'valid' with respect to one another.
1020
-
1021
- Specifically, this checks that each item in one tuple is larger than or
1022
- equal to corresponding item in another tuple.
1687
+ def _check_valid_mode_shapes(shape1: Tuple, shape2: Tuple) -> None:
1688
+ """
1689
+ Check that shape1 is valid for 'valid' mode convolution with shape2.
1023
1690
 
1024
1691
  Parameters
1025
1692
  ----------
1026
- shape1
1027
- shape2
1028
-
1029
- Raises
1030
- ------
1031
- ValueError
1032
- If at least one item in the first shape is not larger than or equal to
1033
- the corresponding item in the second one.
1693
+ shape1 : Tuple
1694
+ First shape tuple to compare
1695
+ shape2 : Tuple
1696
+ Second shape tuple to compare
1697
+
1698
+ Returns
1699
+ -------
1700
+ None
1701
+ This function does not return anything but raises ValueError if condition is not met
1702
+
1703
+ Notes
1704
+ -----
1705
+ This function is used to validate that the first shape has at least as many
1706
+ elements as the second shape in every dimension, which is required for
1707
+ 'valid' mode convolution operations.
1708
+
1709
+ Examples
1710
+ --------
1711
+ >>> _check_valid_mode_shapes((10, 10), (5, 5))
1712
+ >>> _check_valid_mode_shapes((10, 10), (10, 5))
1713
+ >>> _check_valid_mode_shapes((5, 5), (10, 5))
1714
+ Traceback (most recent call last):
1715
+ ...
1716
+ ValueError: in1 should have at least as many items as in2 in every dimension for 'valid' mode.
1034
1717
  """
1035
1718
  for d1, d2 in zip(shape1, shape2):
1036
1719
  if not d1 >= d2:
@@ -1041,22 +1724,29 @@ def _check_valid_mode_shapes(shape1, shape2):
1041
1724
 
1042
1725
 
1043
1726
  def convolve_weighted_fft(
1044
- in1, in2, mode="full", weighting="None", compress=False, displayplots=False
1045
- ):
1046
- """Convolve two N-dimensional arrays using FFT.
1727
+ in1: NDArray[np.floating[Any]],
1728
+ in2: NDArray[np.floating[Any]],
1729
+ mode: str = "full",
1730
+ weighting: str = "None",
1731
+ compress: bool = False,
1732
+ displayplots: bool = False,
1733
+ ) -> NDArray[np.floating[Any]]:
1734
+ """
1735
+ Convolve two N-dimensional arrays using FFT with optional weighting.
1047
1736
 
1048
1737
  Convolve `in1` and `in2` using the fast Fourier transform method, with
1049
- the output size determined by the `mode` argument.
1050
- This is generally much faster than `convolve` for large arrays (n > ~500),
1051
- but can be slower when only a few output values are needed, and can only
1052
- output float arrays (int or object array inputs will be cast to float).
1738
+ the output size determined by the `mode` argument. This is generally much
1739
+ faster than `convolve` for large arrays (n > ~500), but can be slower when
1740
+ only a few output values are needed. The function supports both real and
1741
+ complex inputs, and allows for optional weighting and compression of the
1742
+ FFT operations.
1053
1743
 
1054
1744
  Parameters
1055
1745
  ----------
1056
- in1 : array_like
1057
- First input.
1058
- in2 : array_like
1059
- Second input. Should have the same number of dimensions as `in1`;
1746
+ in1 : NDArray[np.floating[Any]]
1747
+ First input array.
1748
+ in2 : NDArray[np.floating[Any]]
1749
+ Second input array. Should have the same number of dimensions as `in1`;
1060
1750
  if sizes of `in1` and `in2` are not equal then `in1` has to be the
1061
1751
  larger array.
1062
1752
  mode : str {'full', 'valid', 'same'}, optional
@@ -1072,19 +1762,45 @@ def convolve_weighted_fft(
1072
1762
  ``same``
1073
1763
  The output is the same size as `in1`, centered
1074
1764
  with respect to the 'full' output.
1765
+ weighting : str, optional
1766
+ Type of weighting to apply during convolution. Default is "None".
1767
+ Other options may include "uniform", "gaussian", etc., depending on
1768
+ implementation of `gccproduct`.
1769
+ compress : bool, optional
1770
+ If True, compress the FFT data during computation. Default is False.
1771
+ displayplots : bool, optional
1772
+ If True, display intermediate plots during computation. Default is False.
1075
1773
 
1076
1774
  Returns
1077
1775
  -------
1078
- out : array
1776
+ out : NDArray[np.floating[Any]]
1079
1777
  An N-dimensional array containing a subset of the discrete linear
1080
- convolution of `in1` with `in2`.
1081
- """
1082
- in1 = np.asarray(in1)
1083
- in2 = np.asarray(in2)
1778
+ convolution of `in1` with `in2`. The shape of the output depends on
1779
+ the `mode` parameter.
1084
1780
 
1085
- if np.isscalar(in1) and np.isscalar(in2): # scalar inputs
1086
- return in1 * in2
1087
- elif not in1.ndim == in2.ndim:
1781
+ Notes
1782
+ -----
1783
+ - This function uses real FFT (`rfftn`) for real inputs and standard FFT
1784
+ (`fftpack.fftn`) for complex inputs.
1785
+ - The convolution is computed in the frequency domain using the product
1786
+ of FFTs of the inputs.
1787
+ - For real inputs, the result is scaled to preserve the maximum amplitude.
1788
+ - The `gccproduct` function is used internally to compute the product
1789
+ of the FFTs with optional weighting.
1790
+
1791
+ Examples
1792
+ --------
1793
+ >>> import numpy as np
1794
+ >>> a = np.array([[1, 2], [3, 4]])
1795
+ >>> b = np.array([[1, 0], [0, 1]])
1796
+ >>> result = convolve_weighted_fft(a, b)
1797
+ >>> print(result)
1798
+ [[1. 2.]
1799
+ [3. 4.]]
1800
+ """
1801
+ # if np.isscalar(in1) and np.isscalar(in2): # scalar inputs
1802
+ # return in1 * in2
1803
+ if not in1.ndim == in2.ndim:
1088
1804
  raise ValueError("in1 and in2 should have the same rank")
1089
1805
  elif in1.size == 0 or in2.size == 0: # empty arrays
1090
1806
  return np.array([])
@@ -1128,27 +1844,73 @@ def convolve_weighted_fft(
1128
1844
  # scale to preserve the maximum
1129
1845
 
1130
1846
  if mode == "full":
1131
- return ret
1847
+ retval = ret
1132
1848
  elif mode == "same":
1133
- return _centered(ret, s1)
1849
+ retval = _centered(ret, s1)
1134
1850
  elif mode == "valid":
1135
- return _centered(ret, s1 - s2 + 1)
1851
+ retval = _centered(ret, s1 - s2 + 1)
1136
1852
 
1853
+ return retval
1854
+
1855
+
1856
+ def gccproduct(
1857
+ fft1: NDArray,
1858
+ fft2: NDArray,
1859
+ weighting: str,
1860
+ threshfrac: float = 0.1,
1861
+ compress: bool = False,
1862
+ displayplots: bool = False,
1863
+ ) -> NDArray:
1864
+ """
1865
+ Compute the generalized cross-correlation (GCC) product with optional weighting.
1137
1866
 
1138
- def gccproduct(fft1, fft2, weighting, threshfrac=0.1, compress=False, displayplots=False):
1139
- """Calculate product for generalized crosscorrelation.
1867
+ This function computes the GCC product of two FFT arrays, applying a specified
1868
+ weighting scheme to enhance correlation performance. It supports several weighting
1869
+ methods including 'liang', 'eckart', 'phat', and 'regressor'. The result can be
1870
+ thresholded and optionally compressed to improve visualization and reduce noise.
1140
1871
 
1141
1872
  Parameters
1142
1873
  ----------
1143
- fft1
1144
- fft2
1145
- weighting
1146
- threshfrac
1147
- displayplots
1874
+ fft1 : NDArray
1875
+ First FFT array (complex-valued).
1876
+ fft2 : NDArray
1877
+ Second FFT array (complex-valued).
1878
+ weighting : str
1879
+ Weighting method to apply. Options are:
1880
+ - 'liang': Liang weighting
1881
+ - 'eckart': Eckart weighting
1882
+ - 'phat': PHAT (Phase Transform) weighting
1883
+ - 'regressor': Regressor-based weighting (uses fft2 as reference)
1884
+ - 'None': No weighting applied.
1885
+ threshfrac : float, optional
1886
+ Threshold fraction used to determine the minimum value for output masking.
1887
+ Default is 0.1.
1888
+ compress : bool, optional
1889
+ If True, compress the weighting function using 10th and 90th percentiles.
1890
+ Default is False.
1891
+ displayplots : bool, optional
1892
+ If True, display the reciprocal weighting function as a plot.
1893
+ Default is False.
1148
1894
 
1149
1895
  Returns
1150
1896
  -------
1151
- product
1897
+ NDArray
1898
+ The weighted GCC product. The output is of the same shape as the input arrays.
1899
+ If `weighting` is 'None', the raw product is returned.
1900
+ If `threshfrac` is 0, a zero array of the same shape is returned.
1901
+
1902
+ Notes
1903
+ -----
1904
+ The weighting functions are applied element-wise and are designed to suppress
1905
+ noise and enhance correlation peaks. The 'phat' weighting is commonly used in
1906
+ speech and signal processing due to its robustness.
1907
+
1908
+ Examples
1909
+ --------
1910
+ >>> import numpy as np
1911
+ >>> fft1 = np.random.rand(100) + 1j * np.random.rand(100)
1912
+ >>> fft2 = np.random.rand(100) + 1j * np.random.rand(100)
1913
+ >>> result = gccproduct(fft1, fft2, weighting='phat', threshfrac=0.05)
1152
1914
  """
1153
1915
  product = fft1 * fft2
1154
1916
  if weighting == "None":
@@ -1197,17 +1959,75 @@ def gccproduct(fft1, fft2, weighting, threshfrac=0.1, compress=False, displayplo
1197
1959
 
1198
1960
 
1199
1961
  def aligntcwithref(
1200
- fixedtc,
1201
- movingtc,
1202
- Fs,
1203
- lagmin=-30,
1204
- lagmax=30,
1205
- refine=True,
1206
- zerooutbadfit=False,
1207
- widthmax=1000.0,
1208
- display=False,
1209
- verbose=False,
1210
- ):
1962
+ fixedtc: NDArray,
1963
+ movingtc: NDArray,
1964
+ Fs: float,
1965
+ lagmin: float = -30,
1966
+ lagmax: float = 30,
1967
+ refine: bool = True,
1968
+ zerooutbadfit: bool = False,
1969
+ widthmax: float = 1000.0,
1970
+ display: bool = False,
1971
+ verbose: bool = False,
1972
+ ) -> Tuple[NDArray, float, float, int]:
1973
+ """
1974
+ Align a moving timecourse to a fixed reference timecourse using cross-correlation.
1975
+
1976
+ This function computes the cross-correlation between two timecourses and finds the
1977
+ optimal time lag that maximizes their similarity. The moving timecourse is then
1978
+ aligned to the fixed one using this lag.
1979
+
1980
+ Parameters
1981
+ ----------
1982
+ fixedtc : ndarray
1983
+ The reference timecourse to which the moving timecourse will be aligned.
1984
+ movingtc : ndarray
1985
+ The timecourse to be aligned to the fixed timecourse.
1986
+ Fs : float
1987
+ Sampling frequency of the timecourses in Hz.
1988
+ lagmin : float, optional
1989
+ Minimum lag to consider in seconds. Default is -30.
1990
+ lagmax : float, optional
1991
+ Maximum lag to consider in seconds. Default is 30.
1992
+ refine : bool, optional
1993
+ If True, refine the lag estimate using Gaussian fitting. Default is True.
1994
+ zerooutbadfit : bool, optional
1995
+ If True, zero out the cross-correlation values for bad fits. Default is False.
1996
+ widthmax : float, optional
1997
+ Maximum allowed width of the Gaussian fit in samples. Default is 1000.0.
1998
+ display : bool, optional
1999
+ If True, display plots of the cross-correlation and aligned timecourses. Default is False.
2000
+ verbose : bool, optional
2001
+ If True, print detailed information about the cross-correlation results. Default is False.
2002
+
2003
+ Returns
2004
+ -------
2005
+ tuple
2006
+ A tuple containing:
2007
+ - aligneddata : ndarray
2008
+ The moving timecourse aligned to the fixed timecourse.
2009
+ - maxdelay : float
2010
+ The estimated time lag (in seconds) that maximizes cross-correlation.
2011
+ - maxval : float
2012
+ The maximum cross-correlation value.
2013
+ - failreason : int
2014
+ Reason for failure (0 = success, other values indicate specific failure types).
2015
+
2016
+ Notes
2017
+ -----
2018
+ This function uses `fastcorrelate` for efficient cross-correlation computation and
2019
+ `tide_fit.findmaxlag_gauss` to estimate the optimal lag with optional Gaussian refinement.
2020
+ The alignment is performed using `tide_resample.doresample`.
2021
+
2022
+ Examples
2023
+ --------
2024
+ >>> import numpy as np
2025
+ >>> from typing import Tuple
2026
+ >>> fixed = np.random.rand(1000)
2027
+ >>> moving = np.roll(fixed, 10) # shift by 10 samples
2028
+ >>> aligned, delay, corr, fail = aligntcwithref(fixed, moving, Fs=100)
2029
+ >>> print(f"Estimated delay: {delay}s")
2030
+ """
1211
2031
  # now fixedtc and 2 are on the same timescales
1212
2032
  thexcorr = fastcorrelate(tide_math.corrnormalize(fixedtc), tide_math.corrnormalize(movingtc))
1213
2033
  xcorrlen = len(thexcorr)