paradigma 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,12 @@
1
1
  import numpy as np
2
2
  from scipy import signal
3
- from typing import Tuple
4
3
 
5
4
  from paradigma.config import PulseRateConfig
6
5
 
7
6
 
8
7
  def assign_sqa_label(
9
- ppg_prob: np.ndarray,
10
- config: PulseRateConfig,
11
- acc_label=None
12
- ) -> np.ndarray:
8
+ ppg_prob: np.ndarray, config: PulseRateConfig, acc_label: np.ndarray | None = None
9
+ ) -> np.ndarray:
13
10
  """
14
11
  Assigns a signal quality label to every individual data point.
15
12
 
@@ -43,12 +40,13 @@ def assign_sqa_label(
43
40
 
44
41
  for i in range(n_samples):
45
42
  # Start and end indices for current epoch
46
- start_idx = max(0, int((i - (samples_per_epoch - samples_shift)) // fs)) # max to handle first epochs
47
- end_idx = min(int(i // fs), len(ppg_prob)) # min to handle last epochs
43
+ # Max to handle first epochs
44
+ start_idx = max(0, int((i - (samples_per_epoch - samples_shift)) // fs))
45
+ end_idx = min(int(i // fs), len(ppg_prob)) # min to handle last epochs
48
46
 
49
47
  # Extract probabilities and labels for the current epoch
50
- prob = ppg_prob[start_idx:end_idx+1]
51
- label_imu = acc_label[start_idx:end_idx+1]
48
+ prob = ppg_prob[start_idx : end_idx + 1]
49
+ label_imu = acc_label[start_idx : end_idx + 1]
52
50
 
53
51
  # Calculate mean probability and majority voting for labels
54
52
  data_prob[i] = np.mean(prob)
@@ -61,7 +59,9 @@ def assign_sqa_label(
61
59
  return sqa_label
62
60
 
63
61
 
64
- def extract_pr_segments(sqa_label: np.ndarray, min_pr_samples: int) -> Tuple[np.ndarray, np.ndarray]:
62
+ def extract_pr_segments(
63
+ sqa_label: np.ndarray, min_pr_samples: int
64
+ ) -> tuple[np.ndarray, np.ndarray]:
65
65
  """
66
66
  Extracts pulse rate segments based on the SQA label.
67
67
 
@@ -95,28 +95,27 @@ def extract_pr_segments(sqa_label: np.ndarray, min_pr_samples: int) -> Tuple[np.
95
95
 
96
96
 
97
97
  def extract_pr_from_segment(
98
- ppg: np.ndarray,
99
- tfd_length: int,
100
- fs: int,
101
- kern_type: str,
102
- kern_params: dict
103
- ) -> np.ndarray:
98
+ ppg: np.ndarray, tfd_length: int, fs: int, kern_type: str, kern_params: dict
99
+ ) -> np.ndarray:
104
100
  """
105
101
  Extracts pulse rate from the time-frequency distribution of the PPG signal.
106
102
 
107
103
  Parameters
108
104
  ----------
109
105
  ppg : np.ndarray
110
- The preprocessed PPG segment with 2 seconds of padding on both sides to reduce boundary effects.
106
+ The preprocessed PPG segment with 2 seconds of padding on both sides
107
+ to reduce boundary effects.
111
108
  tfd_length : int
112
- Length of each segment (in seconds) to calculate the time-frequency distribution.
109
+ Length of each segment (in seconds) to calculate the time-frequency
110
+ distribution.
113
111
  fs : int
114
112
  The sampling frequency of the PPG signal.
115
113
  kern_type : str
116
- Type of TFD kernel to use (e.g., 'wvd' for Wigner-Ville distribution).
114
+ Type of TFD kernel to use (e.g., 'wvd' for Wigner-Ville
115
+ distribution).
117
116
  kern_params : dict
118
- Parameters for the specified kernel. Not required for 'wvd', but relevant for other
119
- kernels like 'spwvd' or 'swvd'. Default is None.
117
+ Parameters for the specified kernel. Not required for 'wvd', but
118
+ relevant for other kernels like 'spwvd' or 'swvd'. Default is None.
120
119
 
121
120
  Returns
122
121
  -------
@@ -149,20 +148,17 @@ def extract_pr_from_segment(
149
148
  for segment in ppg_segments:
150
149
  # Calculate the time-frequency distribution
151
150
  pr_tfd = extract_pr_with_tfd(segment, fs, kern_type, kern_params)
152
- pr_est_from_ppg = np.concatenate((pr_est_from_ppg, pr_tfd))
151
+ pr_est_from_ppg = np.concatenate((pr_est_from_ppg, pr_tfd))
153
152
 
154
153
  return pr_est_from_ppg
155
154
 
156
155
 
157
156
  def extract_pr_with_tfd(
158
- ppg: np.ndarray,
159
- fs: int,
160
- kern_type: str,
161
- kern_params: dict
162
- ) -> np.ndarray:
157
+ ppg: np.ndarray, fs: int, kern_type: str, kern_params: dict
158
+ ) -> np.ndarray:
163
159
  """
164
- Estimate pulse rate (PR) from a PPG segment using a TFD method with optional
165
- moving average filtering.
160
+ Estimate pulse rate (PR) from a PPG segment using a TFD method with
161
+ optional moving average filtering.
166
162
 
167
163
  Parameters
168
164
  ----------
@@ -178,7 +174,8 @@ def extract_pr_with_tfd(
178
174
  Returns
179
175
  -------
180
176
  pr_smooth_tfd : np.ndarray
181
- Estimated pr values (in beats per minute) for each 2-second segment of the PPG signal.
177
+ Estimated pr values (in beats per minute) for each 2-second segment
178
+ of the PPG signal.
182
179
  """
183
180
  # Generate the TFD matrix using the specified kernel
184
181
  tfd_obj = TimeFreqDistr()
@@ -193,10 +190,14 @@ def extract_pr_with_tfd(
193
190
  max_freq_indices = np.argmax(tfd, axis=0)
194
191
 
195
192
  pr_smooth_tfd = np.array([])
196
- for i in range(2, int(len(ppg) / fs) - 4 + 1, 2): # Skip the first and last 2 seconds, add 1 to include the last segment
193
+ for i in range(
194
+ 2, int(len(ppg) / fs) - 4 + 1, 2
195
+ ): # Skip the first and last 2 seconds, add 1 to include the last segment
197
196
  relevant_indices = (time_axis >= i) & (time_axis < i + 2)
198
197
  avg_frequency = np.mean(freq_axis[max_freq_indices[relevant_indices]])
199
- pr_smooth_tfd = np.concatenate((pr_smooth_tfd, [60 * avg_frequency])) # Convert frequency to BPM
198
+ pr_smooth_tfd = np.concatenate(
199
+ (pr_smooth_tfd, [60 * avg_frequency])
200
+ ) # Convert frequency to BPM
200
201
 
201
202
  return pr_smooth_tfd
202
203
 
@@ -204,33 +205,46 @@ def extract_pr_with_tfd(
204
205
  class TimeFreqDistr:
205
206
  def __init__(self):
206
207
  """
207
- This module contains the implementation of the Generalized Time-Frequency Distribution (TFD) computation using non-separable kernels.
208
- This is a Python implementation of the MATLAB code provided by John O Toole in the following repository: https://github.com/otoolej/memeff_TFDs
209
-
210
- The following functions are implemented for the computation of the TFD:
211
- - nonsep_gdtfd: Computes the generalized time-frequency distribution using a non-separable kernel.
212
- - get_analytic_signal: Generates the analytic signal of the input signal.
213
- - gen_analytic: Generates the analytic signal by zero-padding and performing FFT.
214
- - gen_time_lag: Generates the time-lag distribution of the analytic signal.
215
- - multiply_kernel_signal: Multiplies the TFD by the Doppler-lag kernel.
216
- - gen_doppler_lag_kern: Generates the Doppler-lag kernel based on kernel type and parameters.
208
+ This module contains the implementation of the Generalized
209
+ Time-Frequency Distribution (TFD) computation using non-separable
210
+ kernels. This is a Python implementation of the MATLAB code provided
211
+ by John O Toole in the following repository:
212
+ https://github.com/otoolej/memeff_TFDs
213
+
214
+ The following functions are implemented for the computation of the
215
+ TFD:
216
+ - nonsep_gdtfd: Computes the generalized time-frequency
217
+ distribution using a non-separable kernel.
218
+ - get_analytic_signal: Generates the analytic signal of the input
219
+ signal.
220
+ - gen_analytic: Generates the analytic signal by zero-padding and
221
+ performing FFT.
222
+ - gen_time_lag: Generates the time-lag distribution of the
223
+ analytic signal.
224
+ - multiply_kernel_signal: Multiplies the TFD by the Doppler-lag
225
+ kernel.
226
+ - gen_doppler_lag_kern: Generates the Doppler-lag kernel based on
227
+ kernel type and parameters.
217
228
  - get_kern: Gets the kernel based on the provided kernel type.
218
229
  - get_window: General function to calculate a window function.
219
230
  - get_win: Helper function to create the specified window type.
220
- - shift_window: Shifts the window so that positive indices appear first.
231
+ - shift_window: Shifts the window so that positive indices appear
232
+ first.
221
233
  - pad_window: Zero-pads the window to a specified length.
222
- - compute_tfd: Finalizes the time-frequency distribution computation.
223
- """
234
+ - compute_tfd: Finalizes the time-frequency distribution
235
+ computation.
236
+ """
224
237
  pass
225
238
 
226
239
  def nonsep_gdtfd(
227
- self,
228
- x: np.ndarray,
229
- kern_type: None | str = None,
230
- kern_params: None | dict = None
231
- ):
240
+ self,
241
+ x: np.ndarray,
242
+ kern_type: None | str = None,
243
+ kern_params: None | dict = None,
244
+ ):
232
245
  """
233
- Computes the generalized time-frequency distribution (TFD) using a non-separable kernel.
246
+ Computes the generalized time-frequency distribution (TFD) using a
247
+ non-separable kernel.
234
248
 
235
249
  Parameters:
236
250
  -----------
@@ -247,16 +261,20 @@ class TimeFreqDistr:
247
261
  sep - kernel for separable kernel (combintation of SWVD and PWVD)
248
262
 
249
263
  kern_params : dict, optional
250
- Dictionary of parameters specific to the kernel type. Default is None.
251
- The structure of the dictionary depends on the selected kernel type:
264
+ Dictionary of parameters specific to the kernel type. Default is
265
+ None. The structure of the dictionary depends on the selected
266
+ kernel type:
252
267
  - wvd:
253
268
  An empty dictionary, as no additional parameters are required.
254
269
  - swvd:
255
270
  Dictionary with the following keys:
256
271
  'win_length': Length of the smoothing window.
257
- 'win_type': Type of window function (e.g., 'hamm', 'hann').
258
- 'win_param' (optional): Additional parameters for the window.
259
- 'win_param2' (optional): 0 for time-domain window or 1 for Doppler-domain window.
272
+ 'win_type': Type of window function (e.g., 'hamm',
273
+ 'hann').
274
+ 'win_param' (optional): Additional parameters for the
275
+ window.
276
+ 'win_param2' (optional): 0 for time-domain window or 1
277
+ for Doppler-domain window.
260
278
 
261
279
  Example:
262
280
  ```python
@@ -272,7 +290,8 @@ class TimeFreqDistr:
272
290
  'win_length': Length of the smoothing window.
273
291
  'win_type': Type of window function (e.g., 'cosh').
274
292
  'win_param' (optional): Additional parameters for the window.
275
- 'win_param2' (optional): 0 for time-domain window or 1 for Doppler-domain window.
293
+ 'win_param2' (optional): 0 for time-domain window or 1 for
294
+ Doppler-domain window.
276
295
  Example:
277
296
  ```python
278
297
  kern_params = {
@@ -282,18 +301,23 @@ class TimeFreqDistr:
282
301
  }
283
302
  ```
284
303
  - sep:
285
- Dictionary containing two nested dictionaries, one for the Doppler window and one for the lag window:
304
+ Dictionary containing two nested dictionaries, one for the Doppler
305
+ window and one for the lag window:
286
306
  'doppler': {
287
307
  'win_length': Length of the Doppler-domain window.
288
308
  'win_type': Type of Doppler-domain window function.
289
- 'win_param' (optional): Additional parameters for the Doppler window.
290
- 'win_param2' (optional): 0 for time-domain window or 1 for Doppler-domain window.
309
+ 'win_param' (optional): Additional parameters for the
310
+ Doppler window.
311
+ 'win_param2' (optional): 0 for time-domain window or 1
312
+ for Doppler-domain window.
291
313
  }
292
314
  'lag': {
293
315
  'win_length': Length of the lag-domain window.
294
316
  'win_type': Type of lag-domain window function.
295
- 'win_param' (optional): Additional parameters for the lag window.
296
- 'win_param2' (optional): 0 for time-domain window or 1 for Doppler-domain window.
317
+ 'win_param' (optional): Additional parameters for the lag
318
+ window.
319
+ 'win_param2' (optional): 0 for time-domain window or 1
320
+ for Doppler-domain window.
297
321
  }
298
322
  Example:
299
323
  ```python
@@ -315,18 +339,18 @@ class TimeFreqDistr:
315
339
  The computed time-frequency distribution.
316
340
  """
317
341
  z = self.get_analytic_signal(x)
318
- N = len(z) // 2 # Since z is a signal of length 2N
319
- Nh = int(np.ceil(N / 2))
342
+ n_len = len(z) // 2 # Since z is a signal of length 2*n_len
343
+ n_half = int(np.ceil(n_len / 2))
320
344
 
321
345
  # Generate the time-lag distribution of the analytic signal
322
346
  tfd = self.gen_time_lag(z)
323
347
 
324
348
  # Multiply the TFD by the Doppler-lag kernel
325
- tfd = self.multiply_kernel_signal(tfd, kern_type, kern_params, N, Nh)
326
-
349
+ tfd = self.multiply_kernel_signal(tfd, kern_type, kern_params, n_len, n_half)
350
+
327
351
  # Finalize the TFD computation
328
- tfd = self.compute_tfd(N, Nh, tfd)
329
-
352
+ tfd = self.compute_tfd(n_len, n_half, tfd)
353
+
330
354
  return tfd
331
355
 
332
356
  def get_analytic_signal(self, x: np.ndarray) -> np.ndarray:
@@ -343,16 +367,17 @@ class TimeFreqDistr:
343
367
  z : ndarray
344
368
  Analytic signal with zero-padded imaginary part.
345
369
  """
346
- N = len(x)
370
+ n_len = len(x)
347
371
 
348
- # Ensure the signal length is even by trimming one sample if odd, since the gen_time_lag function requires an even-length signal
349
- if N % 2 != 0:
372
+ # Ensure the signal length is even by trimming one sample if odd,
373
+ # since the gen_time_lag function requires an even-length signal
374
+ if n_len % 2 != 0:
350
375
  x = x[:-1]
351
376
 
352
- # Make the analytical signal of the real-valued signal z (preprocessed PPG signal)
353
- # doesn't work for input of complex numbers
354
- z = self.gen_analytic(x)
355
-
377
+ # Make the analytical signal of the real-valued signal z
378
+ # (preprocessed PPG signal). Doesn't work for input of complex numbers
379
+ z = self.gen_analytic(x)
380
+
356
381
  return z
357
382
 
358
383
  def gen_analytic(self, x: np.ndarray) -> np.ndarray:
@@ -369,22 +394,22 @@ class TimeFreqDistr:
369
394
  z : ndarray
370
395
  Analytic signal in the time domain with zeroed second half.
371
396
  """
372
- N = len(x)
373
-
397
+ n_len = len(x)
398
+
374
399
  # Zero-pad the signal to double its length
375
- x = np.concatenate((np.real(x), np.zeros(N)))
400
+ x = np.concatenate((np.real(x), np.zeros(n_len)))
376
401
  x_fft = np.fft.fft(x)
377
402
 
378
403
  # Generate the analytic signal in the frequency domain
379
- H = np.empty(2 * N) # Preallocate an array of size 2*N
380
- H[0] = 1 # First element
381
- H[1:N] = 2 # Next N-1 elements
382
- H[N] = 1 # Middle element
383
- H[N+1:] = 0 # Last N-1 elements
384
- z_cb = np.fft.ifft(x_fft * H)
404
+ h_analytic = np.empty(2 * n_len) # Preallocate an array of size 2*n_len
405
+ h_analytic[0] = 1 # First element
406
+ h_analytic[1:n_len] = 2 # Next n_len-1 elements
407
+ h_analytic[n_len] = 1 # Middle element
408
+ h_analytic[n_len + 1 :] = 0 # Last n_len-1 elements
409
+ z_cb = np.fft.ifft(x_fft * h_analytic)
385
410
 
386
411
  # Force the second half of the time-domain signal to zero
387
- z = np.concatenate((z_cb[:N], np.zeros(N)))
412
+ z = np.concatenate((z_cb[:n_len], np.zeros(n_len)))
388
413
 
389
414
  return z
390
415
 
@@ -396,43 +421,43 @@ class TimeFreqDistr:
396
421
  -----------
397
422
  z : ndarray
398
423
  Analytic signal of the input signal x.
399
-
424
+
400
425
  Returns:
401
426
  --------
402
427
  tfd : ndarray
403
428
  Time-lag distribution of the analytic signal z.
404
429
 
405
430
  """
406
- N = len(z) // 2 # Assuming z is a signal of length 2N
407
- Nh = int(np.ceil(N / 2))
431
+ n_len = len(z) // 2 # Assuming z is a signal of length 2*n_len
432
+ n_half = int(np.ceil(n_len / 2))
408
433
 
409
434
  # Initialize the time-frequency distribution (TFD) matrix
410
- tfd = np.zeros((N, N), dtype=complex)
435
+ tfd = np.zeros((n_len, n_len), dtype=complex)
436
+
437
+ m = np.arange(n_half)
411
438
 
412
- m = np.arange(Nh)
413
-
414
439
  # Loop over time indices
415
- for n in range(N):
416
- inp = np.mod(n + m, 2 * N)
417
- inn = np.mod(n - m, 2 * N)
440
+ for n in range(n_len):
441
+ inp = np.mod(n + m, 2 * n_len)
442
+ inn = np.mod(n - m, 2 * n_len)
418
443
 
419
444
  # Extract the time slice from the analytic signal
420
- K_time_slice = z[inp] * np.conj(z[inn])
445
+ k_time_slice = z[inp] * np.conj(z[inn])
421
446
 
422
447
  # Store real and imaginary parts
423
- tfd[n, :Nh] = np.real(K_time_slice)
424
- tfd[n, Nh:] = np.imag(K_time_slice)
425
-
448
+ tfd[n, :n_half] = np.real(k_time_slice)
449
+ tfd[n, n_half:] = np.imag(k_time_slice)
450
+
426
451
  return tfd
427
452
 
428
- def multiply_kernel_signal(
453
+ def multiply_kernel_signal(
429
454
  self,
430
- tfd: np.ndarray,
431
- kern_type: str,
432
- kern_params: dict,
433
- N: int,
434
- Nh: int
435
- ) -> np.ndarray:
455
+ tfd: np.ndarray,
456
+ kern_type: str,
457
+ kern_params: dict,
458
+ n_len: int,
459
+ n_half: int,
460
+ ) -> np.ndarray:
436
461
  """
437
462
  Multiplies the TFD by the Doppler-lag kernel.
438
463
 
@@ -455,29 +480,25 @@ class TimeFreqDistr:
455
480
  Modified TFD after kernel multiplication.
456
481
  """
457
482
  # Loop over lag indices
458
- for m in range(Nh):
483
+ for m in range(n_half):
459
484
  # Generate the Doppler-lag kernel for each lag index
460
- g_lag_slice = self.gen_doppler_lag_kern(kern_type, kern_params, N, m)
461
-
485
+ g_lag_slice = self.gen_doppler_lag_kern(kern_type, kern_params, n_len, m)
486
+
462
487
  # Extract and transform the TFD slice for this lag
463
- tfd_slice = np.fft.fft(tfd[:, m]) + 1j * np.fft.fft(tfd[:, Nh + m])
464
-
488
+ tfd_slice = np.fft.fft(tfd[:, m]) + 1j * np.fft.fft(tfd[:, n_half + m])
489
+
465
490
  # Multiply by the kernel and perform inverse FFT
466
- R_lag_slice = np.fft.ifft(tfd_slice * g_lag_slice)
467
-
491
+ r_lag_slice = np.fft.ifft(tfd_slice * g_lag_slice)
492
+
468
493
  # Store real and imaginary parts back into the TFD
469
- tfd[:, m] = np.real(R_lag_slice)
470
- tfd[:, Nh + m] = np.imag(R_lag_slice)
471
-
494
+ tfd[:, m] = np.real(r_lag_slice)
495
+ tfd[:, n_half + m] = np.imag(r_lag_slice)
496
+
472
497
  return tfd
473
498
 
474
499
  def gen_doppler_lag_kern(
475
- self,
476
- kern_type: str,
477
- kern_params: dict,
478
- N: int,
479
- lag_index: int
480
- ):
500
+ self, kern_type: str, kern_params: dict, n_len: int, lag_index: int
501
+ ):
481
502
  """
482
503
  Generate the Doppler-lag kernel based on kernel type and parameters.
483
504
 
@@ -487,7 +508,7 @@ class TimeFreqDistr:
487
508
  Type of kernel (e.g., 'wvd', 'swvd', 'pwvd', etc.).
488
509
  kern_params : dict
489
510
  Parameters for the kernel.
490
- N : int
511
+ n_len : int
491
512
  Signal length.
492
513
  lag_index : int
493
514
  Current lag index.
@@ -497,21 +518,21 @@ class TimeFreqDistr:
497
518
  g : ndarray
498
519
  Doppler-lag kernel for the given lag.
499
520
  """
500
- g = np.zeros(N, dtype=complex) # Initialize the kernel
521
+ g = np.zeros(n_len, dtype=complex) # Initialize the kernel
501
522
 
502
523
  # Get kernel based on the type
503
- g = self.get_kern(g, lag_index, kern_type, kern_params, N)
524
+ g = self.get_kern(g, lag_index, kern_type, kern_params, n_len)
504
525
 
505
- return np.real(g) # All kernels are real valued
526
+ return np.real(g) # All kernels are real valued
506
527
 
507
528
  def get_kern(
508
- self,
509
- g: np.ndarray,
510
- lag_index: int,
511
- kern_type: str,
512
- kern_params: dict,
513
- N: int
514
- ) -> np.ndarray:
529
+ self,
530
+ g: np.ndarray,
531
+ lag_index: int,
532
+ kern_type: str,
533
+ kern_params: dict,
534
+ n_len: int,
535
+ ) -> np.ndarray:
515
536
  """
516
537
  Get the kernel based on the provided kernel type.
517
538
 
@@ -525,7 +546,7 @@ class TimeFreqDistr:
525
546
  Type of kernel to use (now included: 'wvd', 'swvd', 'pwvd', 'sep').
526
547
  kern_params : dict
527
548
  Parameters for the specified kernel.
528
- N : int
549
+ n_len : int
529
550
  Signal length.
530
551
 
531
552
  Returns:
@@ -534,58 +555,72 @@ class TimeFreqDistr:
534
555
  Kernel function at the current lag.
535
556
  """
536
557
  # Validate kern_type
537
- valid_kern_types = ['wvd', 'sep', 'swvd', 'pwvd'] # List of valid kernel types which are currently supported
558
+ valid_kern_types = [
559
+ "wvd",
560
+ "sep",
561
+ "swvd",
562
+ "pwvd",
563
+ ] # List of valid kernel types which are currently supported
538
564
  if kern_type not in valid_kern_types:
539
- raise ValueError(f"Unknown kernel type: {kern_type}. Expected one of {valid_kern_types}")
540
-
565
+ raise ValueError(
566
+ f"Unknown kernel type: {kern_type}. Expected one of {valid_kern_types}"
567
+ )
568
+
541
569
  num_params = len(kern_params)
542
570
 
543
- if kern_type == 'wvd':
544
- g[:] = 1 # WVD kernel is the equal to 1 for all lags
571
+ if kern_type == "wvd":
572
+ g[:] = 1 # WVD kernel is the equal to 1 for all lags
545
573
 
546
- elif kern_type == 'sep':
574
+ elif kern_type == "sep":
547
575
  # Separable Kernel
548
576
  g1 = np.copy(g) # Create a new array for g1
549
577
  g2 = np.copy(g) # Create a new array for g2
550
-
551
- # Call recursively to obtain g1 and g2 kernels (no in-place modification of g)
552
- g1 = self.get_kern(g1, lag_index, 'swvd', kern_params['lag'], N) # Generate the first kernel
553
- g2 = self.get_kern(g2, lag_index, 'pwvd', kern_params['doppler'], N) # Generate the second kernel
554
- g = g1 * g2 # Multiply the two kernels to obtain the separable kernel
578
+
579
+ # Call recursively to obtain g1 and g2 kernels (no in-place
580
+ # modification of g)
581
+ g1 = self.get_kern(
582
+ g1, lag_index, "swvd", kern_params["lag"], n_len
583
+ ) # Generate the first kernel
584
+ g2 = self.get_kern(
585
+ g2, lag_index, "pwvd", kern_params["doppler"], n_len
586
+ ) # Generate the second kernel
587
+ g = g1 * g2 # Multiply the two kernels to obtain the separable kernel
555
588
 
556
589
  else:
557
590
  if num_params < 2:
558
- raise ValueError("Missing required kernel parameters: 'win_length' and 'win_type'")
591
+ raise ValueError(
592
+ "Missing required kernel parameters: 'win_length' and 'win_type'"
593
+ )
559
594
 
560
- win_length = kern_params['win_length']
561
- win_type = kern_params['win_type']
562
- win_param = kern_params['win_param'] if 'win_param' in kern_params else 0
563
- win_param2 = kern_params['win_param2'] if 'win_param2' in kern_params else 1
595
+ win_length = kern_params["win_length"]
596
+ win_type = kern_params["win_type"]
597
+ win_param = kern_params["win_param"] if "win_param" in kern_params else 0
598
+ win_param2 = kern_params["win_param2"] if "win_param2" in kern_params else 1
564
599
 
565
- G = self.get_window(win_length, win_type, win_param)
566
- G = self.pad_window(G, N)
600
+ g_window = self.get_window(win_length, win_type, win_param)
601
+ g_window = self.pad_window(g_window, n_len)
567
602
 
568
- if kern_type == 'swvd' and win_param2 == 0:
569
- G = np.fft.fft(G)
570
- if G[0] != 0: # add this check to avoid division by zero
571
- G /= G[0]
572
- G = G[lag_index]
603
+ if kern_type == "swvd" and win_param2 == 0:
604
+ g_window = np.fft.fft(g_window)
605
+ if g_window[0] != 0: # add this check to avoid division by zero
606
+ g_window /= g_window[0]
607
+ g_window = g_window[lag_index]
573
608
 
574
- g[:] = G
609
+ g[:] = g_window
575
610
 
576
611
  return g
577
612
 
578
613
  def get_window(
579
- self,
580
- win_length: int,
581
- win_type: str,
582
- win_param: float | None = None,
583
- dft_window: bool = False,
584
- Npad: int = 0
585
- ) -> np.ndarray:
614
+ self,
615
+ win_length: int,
616
+ win_type: str,
617
+ win_param: float | None = None,
618
+ dft_window: bool = False,
619
+ n_pad: int = 0,
620
+ ) -> np.ndarray:
586
621
  """
587
622
  General function to calculate a window function.
588
-
623
+
589
624
  Parameters:
590
625
  -----------
591
626
  win_length : int
@@ -597,37 +632,37 @@ class TimeFreqDistr:
597
632
  Window parameter (e.g., alpha for Gaussian window). Default is None.
598
633
  dft_window : bool, optional
599
634
  If True, returns the DFT of the window. Default is False.
600
- Npad : int, optional
601
- If greater than 0, zero-pads the window to length Npad. Default is 0.
602
-
635
+ n_pad : int, optional
636
+ If greater than 0, zero-pads the window to length n_pad. Default is 0.
637
+
603
638
  Returns:
604
639
  --------
605
640
  win : ndarray
606
641
  The calculated window (or its DFT if dft_window is True).
607
642
  """
608
-
643
+
609
644
  # Get the window
610
645
  win = self.get_win(win_length, win_type, win_param, dft_window)
611
-
646
+
612
647
  # Shift the window so that positive indices are first
613
648
  win = self.shift_window(win)
614
-
615
- # Zero-pad the window to length Npad if necessary
616
- if Npad > 0:
617
- win = self.pad_window(win, Npad)
618
-
649
+
650
+ # Zero-pad the window to length n_pad if necessary
651
+ if n_pad > 0:
652
+ win = self.pad_window(win, n_pad)
653
+
619
654
  return win
620
655
 
621
656
  def get_win(
622
- self,
623
- win_length: int,
624
- win_type: str,
625
- win_param: float | None = None,
626
- dft_window: bool = False
627
- ) -> np.ndarray:
657
+ self,
658
+ win_length: int,
659
+ win_type: str,
660
+ win_param: float | None = None,
661
+ dft_window: bool = False,
662
+ ) -> np.ndarray:
628
663
  """
629
664
  Helper function to create the specified window type.
630
-
665
+
631
666
  Parameters:
632
667
  -----------
633
668
  win_length : int
@@ -635,115 +670,115 @@ class TimeFreqDistr:
635
670
  win_type : str
636
671
  Type of window.
637
672
  win_param : float, optional
638
- Additional parameter for certain window types (e.g., Gaussian alpha). Default is None.
673
+ Additional parameter for certain window types (e.g., Gaussian
674
+ alpha). Default is None.
639
675
  dft_window : bool, optional
640
676
  If True, returns the DFT of the window. Default is False.
641
-
677
+
642
678
  Returns:
643
679
  --------
644
680
  win : ndarray
645
681
  The created window (or its DFT if dft_window is True).
646
682
  """
647
- if win_type == 'delta':
683
+ if win_type == "delta":
648
684
  win = np.zeros(win_length)
649
685
  win[win_length // 2] = 1
650
- elif win_type == 'rect':
686
+ elif win_type == "rect":
651
687
  win = np.ones(win_length)
652
- elif win_type in ['hamm', 'hamming']:
688
+ elif win_type in ["hamm", "hamming"]:
653
689
  win = signal.windows.hamming(win_length)
654
- elif win_type in ['hann', 'hanning']:
690
+ elif win_type in ["hann", "hanning"]:
655
691
  win = signal.windows.hann(win_length)
656
- elif win_type == 'gauss':
657
- win = signal.windows.gaussian(win_length, std=win_param if win_param else 0.4)
658
- elif win_type == 'cosh':
692
+ elif win_type == "gauss":
693
+ win = signal.windows.gaussian(
694
+ win_length, std=win_param if win_param else 0.4
695
+ )
696
+ elif win_type == "cosh":
659
697
  win_hlf = win_length // 2
660
698
  if not win_param:
661
699
  win_param = 0.01
662
- win = np.array([np.cosh(m) ** (-2 * win_param) for m in range(-win_hlf, win_hlf+1)])
700
+ win = np.array(
701
+ [np.cosh(m) ** (-2 * win_param) for m in range(-win_hlf, win_hlf + 1)]
702
+ )
663
703
  win = np.fft.fftshift(win)
664
704
  else:
665
705
  raise ValueError(f"Unknown window type {win_type}")
666
-
706
+
667
707
  # If dft_window is True, return the DFT of the window
668
708
  if dft_window:
669
709
  win = np.fft.fft(np.roll(win, win_length // 2))
670
710
  win = np.roll(win, -win_length // 2)
671
-
711
+
672
712
  return win
673
713
 
674
714
  def shift_window(self, w: np.ndarray) -> np.ndarray:
675
715
  """
676
716
  Shift the window so that positive indices appear first.
677
-
717
+
678
718
  Parameters:
679
719
  -----------
680
720
  w : ndarray
681
721
  Window to be shifted.
682
-
722
+
683
723
  Returns:
684
724
  --------
685
725
  w_shifted : ndarray
686
726
  Shifted window with positive indices first.
687
727
  """
688
- N = len(w)
689
- return np.roll(w, N // 2)
728
+ n_len = len(w)
729
+ return np.roll(w, n_len // 2)
690
730
 
691
- def pad_window(self, w: np.ndarray, Npad: int) -> np.ndarray:
731
+ def pad_window(self, w: np.ndarray, n_pad: int) -> np.ndarray:
692
732
  """
693
733
  Zero-pad the window to a specified length.
694
-
734
+
695
735
  Parameters:
696
736
  -----------
697
737
  w : ndarray
698
738
  The original window.
699
- Npad : int
739
+ n_pad : int
700
740
  Length to zero-pad the window to.
701
-
741
+
702
742
  Returns:
703
743
  --------
704
744
  w_pad : ndarray
705
- Zero-padded window of length Npad.
706
-
745
+ Zero-padded window of length n_pad.
746
+
707
747
  Raises:
708
748
  -------
709
749
  ValueError:
710
- If Npad is less than the original window length.
750
+ If n_pad is less than the original window length.
711
751
  """
712
- N = len(w)
713
- w_pad = np.zeros(Npad)
714
- Nh = N // 2
715
-
716
- if Npad < N:
717
- raise ValueError("Npad must be greater than or equal to the window length")
718
-
719
- if N == Npad:
752
+ n_len = len(w)
753
+ w_pad = np.zeros(n_pad)
754
+ n_half = n_len // 2
755
+
756
+ if n_pad < n_len:
757
+ raise ValueError("n_pad must be greater than or equal to the window length")
758
+
759
+ if n_len == n_pad:
720
760
  return w
721
-
722
- if N % 2 == 1: # For odd N
723
- w_pad[:Nh+1] = w[:Nh+1]
724
- w_pad[-Nh:] = w[-Nh:]
725
- else: # For even N
726
- w_pad[:Nh] = w[:Nh]
727
- w_pad[Nh] = w[Nh] / 2
728
- w_pad[-Nh:] = w[-Nh:]
729
- w_pad[-Nh] = w[Nh] / 2
730
-
761
+
762
+ if n_len % 2 == 1: # For odd n_len
763
+ w_pad[: n_half + 1] = w[: n_half + 1]
764
+ w_pad[-n_half:] = w[-n_half:]
765
+ else: # For even n_len
766
+ w_pad[:n_half] = w[:n_half]
767
+ w_pad[n_half] = w[n_half] / 2
768
+ w_pad[-n_half:] = w[-n_half:]
769
+ w_pad[-n_half] = w[n_half] / 2
770
+
731
771
  return w_pad
732
772
 
733
- def compute_tfd(
734
- self,
735
- N: int,
736
- Nh: int,
737
- tfd: np.ndarray
738
- ):
773
+ def compute_tfd(self, n_len: int, n_half: int, tfd: np.ndarray):
739
774
  """
740
775
  Finalizes the time-frequency distribution computation.
741
776
 
742
777
  Parameters:
743
778
  -----------
744
- N : int
779
+ n_len : int
745
780
  Size of the TFD.
746
- Nh : int
781
+ n_half : int
747
782
  Half-length parameter.
748
783
  tfd : np.ndarray
749
784
  Time-frequency distribution to be finalized.
@@ -751,30 +786,36 @@ class TimeFreqDistr:
751
786
  Returns:
752
787
  --------
753
788
  tfd : np.ndarray
754
- Final computed TFD (N,N).
789
+ Final computed TFD (n_len, n_len).
755
790
  """
756
- m = np.arange(0, Nh) # m = 0:(Nh-1)
757
- mb = np.arange(1, Nh) # mb = 1:(Nh-1)
791
+ m = np.arange(0, n_half) # m = 0:(n_half-1)
792
+ mb = np.arange(1, n_half) # mb = 1:(n_half-1)
793
+
794
+ for n in range(0, n_len - 1, 2): # for n=0:2:n_len-2
795
+ r_even_half = np.complex128(tfd[n, :n_half]) + 1j * np.complex128(
796
+ tfd[n, n_half:]
797
+ )
798
+ r_odd_half = np.complex128(tfd[n + 1, :n_half]) + 1j * np.complex128(
799
+ tfd[n + 1, n_half:]
800
+ )
758
801
 
759
- for n in range(0, N-1, 2): # for n=0:2:N-2
760
- R_even_half = np.complex128(tfd[n, :Nh]) + 1j * np.complex128(tfd[n, Nh:])
761
- R_odd_half = np.complex128(tfd[n+1, :Nh]) + 1j * np.complex128(tfd[n+1, Nh:])
802
+ r_tslice_even = np.zeros(n_len, dtype=np.complex128)
803
+ r_tslice_odd = np.zeros(n_len, dtype=np.complex128)
762
804
 
763
- R_tslice_even = np.zeros(N, dtype=np.complex128)
764
- R_tslice_odd = np.zeros(N, dtype=np.complex128)
805
+ r_tslice_even[m] = r_even_half
806
+ r_tslice_odd[m] = r_odd_half
765
807
 
766
- R_tslice_even[m] = R_even_half
767
- R_tslice_odd[m] = R_odd_half
808
+ r_tslice_even[n_len - mb] = np.conj(r_even_half[mb])
809
+ r_tslice_odd[n_len - mb] = np.conj(r_odd_half[mb])
768
810
 
769
- R_tslice_even[N-mb] = np.conj(R_even_half[mb])
770
- R_tslice_odd[N-mb] = np.conj(R_odd_half[mb])
771
-
772
811
  # Perform FFT to compute time slices
773
- tfd_time_slice = np.fft.fft(R_tslice_even + 1j * R_tslice_odd)
812
+ tfd_time_slice = np.fft.fft(r_tslice_even + 1j * r_tslice_odd)
774
813
 
775
814
  tfd[n, :] = np.real(tfd_time_slice)
776
- tfd[n+1, :] = np.imag(tfd_time_slice)
815
+ tfd[n + 1, :] = np.imag(tfd_time_slice)
777
816
 
778
- tfd = tfd / N # Normalize the TFD
779
- tfd = tfd.transpose() # Transpose the TFD to have the time on the x-axis and frequency on the y-axis
817
+ tfd = tfd / n_len # Normalize the TFD
818
+ # Transpose the TFD to have the time on the x-axis and frequency on
819
+ # the y-axis
820
+ tfd = tfd.transpose()
780
821
  return tfd