paradigma 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,14 @@
1
+ from typing import Tuple
2
+
1
3
  import numpy as np
2
4
  from scipy import signal
3
- from typing import Tuple
4
5
 
5
6
  from paradigma.config import PulseRateConfig
6
7
 
7
8
 
8
9
  def assign_sqa_label(
9
- ppg_prob: np.ndarray,
10
- config: PulseRateConfig,
11
- acc_label=None
12
- ) -> np.ndarray:
10
+ ppg_prob: np.ndarray, config: PulseRateConfig, acc_label: np.ndarray | None = None
11
+ ) -> np.ndarray:
13
12
  """
14
13
  Assigns a signal quality label to every individual data point.
15
14
 
@@ -43,12 +42,14 @@ def assign_sqa_label(
43
42
 
44
43
  for i in range(n_samples):
45
44
  # Start and end indices for current epoch
46
- start_idx = max(0, int((i - (samples_per_epoch - samples_shift)) // fs)) # max to handle first epochs
47
- end_idx = min(int(i // fs), len(ppg_prob)) # min to handle last epochs
45
+ start_idx = max(
46
+ 0, int((i - (samples_per_epoch - samples_shift)) // fs)
47
+ ) # max to handle first epochs
48
+ end_idx = min(int(i // fs), len(ppg_prob)) # min to handle last epochs
48
49
 
49
50
  # Extract probabilities and labels for the current epoch
50
- prob = ppg_prob[start_idx:end_idx+1]
51
- label_imu = acc_label[start_idx:end_idx+1]
51
+ prob = ppg_prob[start_idx : end_idx + 1]
52
+ label_imu = acc_label[start_idx : end_idx + 1]
52
53
 
53
54
  # Calculate mean probability and majority voting for labels
54
55
  data_prob[i] = np.mean(prob)
@@ -61,7 +62,9 @@ def assign_sqa_label(
61
62
  return sqa_label
62
63
 
63
64
 
64
- def extract_pr_segments(sqa_label: np.ndarray, min_pr_samples: int) -> Tuple[np.ndarray, np.ndarray]:
65
+ def extract_pr_segments(
66
+ sqa_label: np.ndarray, min_pr_samples: int
67
+ ) -> Tuple[np.ndarray, np.ndarray]:
65
68
  """
66
69
  Extracts pulse rate segments based on the SQA label.
67
70
 
@@ -95,12 +98,8 @@ def extract_pr_segments(sqa_label: np.ndarray, min_pr_samples: int) -> Tuple[np.
95
98
 
96
99
 
97
100
  def extract_pr_from_segment(
98
- ppg: np.ndarray,
99
- tfd_length: int,
100
- fs: int,
101
- kern_type: str,
102
- kern_params: dict
103
- ) -> np.ndarray:
101
+ ppg: np.ndarray, tfd_length: int, fs: int, kern_type: str, kern_params: dict
102
+ ) -> np.ndarray:
104
103
  """
105
104
  Extracts pulse rate from the time-frequency distribution of the PPG signal.
106
105
 
@@ -115,7 +114,7 @@ def extract_pr_from_segment(
115
114
  kern_type : str
116
115
  Type of TFD kernel to use (e.g., 'wvd' for Wigner-Ville distribution).
117
116
  kern_params : dict
118
- Parameters for the specified kernel. Not required for 'wvd', but relevant for other
117
+ Parameters for the specified kernel. Not required for 'wvd', but relevant for other
119
118
  kernels like 'spwvd' or 'swvd'. Default is None.
120
119
 
121
120
  Returns
@@ -149,19 +148,16 @@ def extract_pr_from_segment(
149
148
  for segment in ppg_segments:
150
149
  # Calculate the time-frequency distribution
151
150
  pr_tfd = extract_pr_with_tfd(segment, fs, kern_type, kern_params)
152
- pr_est_from_ppg = np.concatenate((pr_est_from_ppg, pr_tfd))
151
+ pr_est_from_ppg = np.concatenate((pr_est_from_ppg, pr_tfd))
153
152
 
154
153
  return pr_est_from_ppg
155
154
 
156
155
 
157
156
  def extract_pr_with_tfd(
158
- ppg: np.ndarray,
159
- fs: int,
160
- kern_type: str,
161
- kern_params: dict
162
- ) -> np.ndarray:
157
+ ppg: np.ndarray, fs: int, kern_type: str, kern_params: dict
158
+ ) -> np.ndarray:
163
159
  """
164
- Estimate pulse rate (PR) from a PPG segment using a TFD method with optional
160
+ Estimate pulse rate (PR) from a PPG segment using a TFD method with optional
165
161
  moving average filtering.
166
162
 
167
163
  Parameters
@@ -193,10 +189,14 @@ def extract_pr_with_tfd(
193
189
  max_freq_indices = np.argmax(tfd, axis=0)
194
190
 
195
191
  pr_smooth_tfd = np.array([])
196
- for i in range(2, int(len(ppg) / fs) - 4 + 1, 2): # Skip the first and last 2 seconds, add 1 to include the last segment
192
+ for i in range(
193
+ 2, int(len(ppg) / fs) - 4 + 1, 2
194
+ ): # Skip the first and last 2 seconds, add 1 to include the last segment
197
195
  relevant_indices = (time_axis >= i) & (time_axis < i + 2)
198
196
  avg_frequency = np.mean(freq_axis[max_freq_indices[relevant_indices]])
199
- pr_smooth_tfd = np.concatenate((pr_smooth_tfd, [60 * avg_frequency])) # Convert frequency to BPM
197
+ pr_smooth_tfd = np.concatenate(
198
+ (pr_smooth_tfd, [60 * avg_frequency])
199
+ ) # Convert frequency to BPM
200
200
 
201
201
  return pr_smooth_tfd
202
202
 
@@ -205,7 +205,7 @@ class TimeFreqDistr:
205
205
  def __init__(self):
206
206
  """
207
207
  This module contains the implementation of the Generalized Time-Frequency Distribution (TFD) computation using non-separable kernels.
208
- This is a Python implementation of the MATLAB code provided by John O Toole in the following repository: https://github.com/otoolej/memeff_TFDs
208
+ This is a Python implementation of the MATLAB code provided by John O Toole in the following repository: https://github.com/otoolej/memeff_TFDs
209
209
 
210
210
  The following functions are implemented for the computation of the TFD:
211
211
  - nonsep_gdtfd: Computes the generalized time-frequency distribution using a non-separable kernel.
@@ -220,15 +220,15 @@ class TimeFreqDistr:
220
220
  - shift_window: Shifts the window so that positive indices appear first.
221
221
  - pad_window: Zero-pads the window to a specified length.
222
222
  - compute_tfd: Finalizes the time-frequency distribution computation.
223
- """
223
+ """
224
224
  pass
225
225
 
226
226
  def nonsep_gdtfd(
227
- self,
228
- x: np.ndarray,
229
- kern_type: None | str = None,
230
- kern_params: None | dict = None
231
- ):
227
+ self,
228
+ x: np.ndarray,
229
+ kern_type: None | str = None,
230
+ kern_params: None | dict = None,
231
+ ):
232
232
  """
233
233
  Computes the generalized time-frequency distribution (TFD) using a non-separable kernel.
234
234
 
@@ -323,10 +323,10 @@ class TimeFreqDistr:
323
323
 
324
324
  # Multiply the TFD by the Doppler-lag kernel
325
325
  tfd = self.multiply_kernel_signal(tfd, kern_type, kern_params, N, Nh)
326
-
326
+
327
327
  # Finalize the TFD computation
328
328
  tfd = self.compute_tfd(N, Nh, tfd)
329
-
329
+
330
330
  return tfd
331
331
 
332
332
  def get_analytic_signal(self, x: np.ndarray) -> np.ndarray:
@@ -351,8 +351,8 @@ class TimeFreqDistr:
351
351
 
352
352
  # Make the analytical signal of the real-valued signal z (preprocessed PPG signal)
353
353
  # doesn't work for input of complex numbers
354
- z = self.gen_analytic(x)
355
-
354
+ z = self.gen_analytic(x)
355
+
356
356
  return z
357
357
 
358
358
  def gen_analytic(self, x: np.ndarray) -> np.ndarray:
@@ -370,17 +370,17 @@ class TimeFreqDistr:
370
370
  Analytic signal in the time domain with zeroed second half.
371
371
  """
372
372
  N = len(x)
373
-
373
+
374
374
  # Zero-pad the signal to double its length
375
375
  x = np.concatenate((np.real(x), np.zeros(N)))
376
376
  x_fft = np.fft.fft(x)
377
377
 
378
378
  # Generate the analytic signal in the frequency domain
379
- H = np.empty(2 * N) # Preallocate an array of size 2*N
380
- H[0] = 1 # First element
381
- H[1:N] = 2 # Next N-1 elements
382
- H[N] = 1 # Middle element
383
- H[N+1:] = 0 # Last N-1 elements
379
+ H = np.empty(2 * N) # Preallocate an array of size 2*N
380
+ H[0] = 1 # First element
381
+ H[1:N] = 2 # Next N-1 elements
382
+ H[N] = 1 # Middle element
383
+ H[N + 1 :] = 0 # Last N-1 elements
384
384
  z_cb = np.fft.ifft(x_fft * H)
385
385
 
386
386
  # Force the second half of the time-domain signal to zero
@@ -396,7 +396,7 @@ class TimeFreqDistr:
396
396
  -----------
397
397
  z : ndarray
398
398
  Analytic signal of the input signal x.
399
-
399
+
400
400
  Returns:
401
401
  --------
402
402
  tfd : ndarray
@@ -410,7 +410,7 @@ class TimeFreqDistr:
410
410
  tfd = np.zeros((N, N), dtype=complex)
411
411
 
412
412
  m = np.arange(Nh)
413
-
413
+
414
414
  # Loop over time indices
415
415
  for n in range(N):
416
416
  inp = np.mod(n + m, 2 * N)
@@ -422,17 +422,12 @@ class TimeFreqDistr:
422
422
  # Store real and imaginary parts
423
423
  tfd[n, :Nh] = np.real(K_time_slice)
424
424
  tfd[n, Nh:] = np.imag(K_time_slice)
425
-
425
+
426
426
  return tfd
427
427
 
428
- def multiply_kernel_signal(
429
- self,
430
- tfd: np.ndarray,
431
- kern_type: str,
432
- kern_params: dict,
433
- N: int,
434
- Nh: int
435
- ) -> np.ndarray:
428
+ def multiply_kernel_signal(
429
+ self, tfd: np.ndarray, kern_type: str, kern_params: dict, N: int, Nh: int
430
+ ) -> np.ndarray:
436
431
  """
437
432
  Multiplies the TFD by the Doppler-lag kernel.
438
433
 
@@ -458,26 +453,22 @@ class TimeFreqDistr:
458
453
  for m in range(Nh):
459
454
  # Generate the Doppler-lag kernel for each lag index
460
455
  g_lag_slice = self.gen_doppler_lag_kern(kern_type, kern_params, N, m)
461
-
456
+
462
457
  # Extract and transform the TFD slice for this lag
463
458
  tfd_slice = np.fft.fft(tfd[:, m]) + 1j * np.fft.fft(tfd[:, Nh + m])
464
-
459
+
465
460
  # Multiply by the kernel and perform inverse FFT
466
461
  R_lag_slice = np.fft.ifft(tfd_slice * g_lag_slice)
467
-
462
+
468
463
  # Store real and imaginary parts back into the TFD
469
464
  tfd[:, m] = np.real(R_lag_slice)
470
465
  tfd[:, Nh + m] = np.imag(R_lag_slice)
471
-
466
+
472
467
  return tfd
473
468
 
474
469
  def gen_doppler_lag_kern(
475
- self,
476
- kern_type: str,
477
- kern_params: dict,
478
- N: int,
479
- lag_index: int
480
- ):
470
+ self, kern_type: str, kern_params: dict, N: int, lag_index: int
471
+ ):
481
472
  """
482
473
  Generate the Doppler-lag kernel based on kernel type and parameters.
483
474
 
@@ -502,16 +493,11 @@ class TimeFreqDistr:
502
493
  # Get kernel based on the type
503
494
  g = self.get_kern(g, lag_index, kern_type, kern_params, N)
504
495
 
505
- return np.real(g) # All kernels are real valued
496
+ return np.real(g) # All kernels are real valued
506
497
 
507
498
  def get_kern(
508
- self,
509
- g: np.ndarray,
510
- lag_index: int,
511
- kern_type: str,
512
- kern_params: dict,
513
- N: int
514
- ) -> np.ndarray:
499
+ self, g: np.ndarray, lag_index: int, kern_type: str, kern_params: dict, N: int
500
+ ) -> np.ndarray:
515
501
  """
516
502
  Get the kernel based on the provided kernel type.
517
503
 
@@ -534,38 +520,51 @@ class TimeFreqDistr:
534
520
  Kernel function at the current lag.
535
521
  """
536
522
  # Validate kern_type
537
- valid_kern_types = ['wvd', 'sep', 'swvd', 'pwvd'] # List of valid kernel types which are currently supported
523
+ valid_kern_types = [
524
+ "wvd",
525
+ "sep",
526
+ "swvd",
527
+ "pwvd",
528
+ ] # List of valid kernel types which are currently supported
538
529
  if kern_type not in valid_kern_types:
539
- raise ValueError(f"Unknown kernel type: {kern_type}. Expected one of {valid_kern_types}")
540
-
530
+ raise ValueError(
531
+ f"Unknown kernel type: {kern_type}. Expected one of {valid_kern_types}"
532
+ )
533
+
541
534
  num_params = len(kern_params)
542
535
 
543
- if kern_type == 'wvd':
544
- g[:] = 1 # WVD kernel is the equal to 1 for all lags
536
+ if kern_type == "wvd":
537
+ g[:] = 1 # WVD kernel is the equal to 1 for all lags
545
538
 
546
- elif kern_type == 'sep':
539
+ elif kern_type == "sep":
547
540
  # Separable Kernel
548
541
  g1 = np.copy(g) # Create a new array for g1
549
542
  g2 = np.copy(g) # Create a new array for g2
550
-
543
+
551
544
  # Call recursively to obtain g1 and g2 kernels (no in-place modification of g)
552
- g1 = self.get_kern(g1, lag_index, 'swvd', kern_params['lag'], N) # Generate the first kernel
553
- g2 = self.get_kern(g2, lag_index, 'pwvd', kern_params['doppler'], N) # Generate the second kernel
554
- g = g1 * g2 # Multiply the two kernels to obtain the separable kernel
545
+ g1 = self.get_kern(
546
+ g1, lag_index, "swvd", kern_params["lag"], N
547
+ ) # Generate the first kernel
548
+ g2 = self.get_kern(
549
+ g2, lag_index, "pwvd", kern_params["doppler"], N
550
+ ) # Generate the second kernel
551
+ g = g1 * g2 # Multiply the two kernels to obtain the separable kernel
555
552
 
556
553
  else:
557
554
  if num_params < 2:
558
- raise ValueError("Missing required kernel parameters: 'win_length' and 'win_type'")
555
+ raise ValueError(
556
+ "Missing required kernel parameters: 'win_length' and 'win_type'"
557
+ )
559
558
 
560
- win_length = kern_params['win_length']
561
- win_type = kern_params['win_type']
562
- win_param = kern_params['win_param'] if 'win_param' in kern_params else 0
563
- win_param2 = kern_params['win_param2'] if 'win_param2' in kern_params else 1
559
+ win_length = kern_params["win_length"]
560
+ win_type = kern_params["win_type"]
561
+ win_param = kern_params["win_param"] if "win_param" in kern_params else 0
562
+ win_param2 = kern_params["win_param2"] if "win_param2" in kern_params else 1
564
563
 
565
564
  G = self.get_window(win_length, win_type, win_param)
566
565
  G = self.pad_window(G, N)
567
566
 
568
- if kern_type == 'swvd' and win_param2 == 0:
567
+ if kern_type == "swvd" and win_param2 == 0:
569
568
  G = np.fft.fft(G)
570
569
  if G[0] != 0: # add this check to avoid division by zero
571
570
  G /= G[0]
@@ -576,16 +575,16 @@ class TimeFreqDistr:
576
575
  return g
577
576
 
578
577
  def get_window(
579
- self,
580
- win_length: int,
581
- win_type: str,
582
- win_param: float | None = None,
583
- dft_window: bool = False,
584
- Npad: int = 0
585
- ) -> np.ndarray:
578
+ self,
579
+ win_length: int,
580
+ win_type: str,
581
+ win_param: float | None = None,
582
+ dft_window: bool = False,
583
+ Npad: int = 0,
584
+ ) -> np.ndarray:
586
585
  """
587
586
  General function to calculate a window function.
588
-
587
+
589
588
  Parameters:
590
589
  -----------
591
590
  win_length : int
@@ -599,35 +598,35 @@ class TimeFreqDistr:
599
598
  If True, returns the DFT of the window. Default is False.
600
599
  Npad : int, optional
601
600
  If greater than 0, zero-pads the window to length Npad. Default is 0.
602
-
601
+
603
602
  Returns:
604
603
  --------
605
604
  win : ndarray
606
605
  The calculated window (or its DFT if dft_window is True).
607
606
  """
608
-
607
+
609
608
  # Get the window
610
609
  win = self.get_win(win_length, win_type, win_param, dft_window)
611
-
610
+
612
611
  # Shift the window so that positive indices are first
613
612
  win = self.shift_window(win)
614
-
613
+
615
614
  # Zero-pad the window to length Npad if necessary
616
615
  if Npad > 0:
617
616
  win = self.pad_window(win, Npad)
618
-
617
+
619
618
  return win
620
619
 
621
620
  def get_win(
622
- self,
623
- win_length: int,
624
- win_type: str,
625
- win_param: float | None = None,
626
- dft_window: bool = False
627
- ) -> np.ndarray:
621
+ self,
622
+ win_length: int,
623
+ win_type: str,
624
+ win_param: float | None = None,
625
+ dft_window: bool = False,
626
+ ) -> np.ndarray:
628
627
  """
629
628
  Helper function to create the specified window type.
630
-
629
+
631
630
  Parameters:
632
631
  -----------
633
632
  win_length : int
@@ -638,48 +637,52 @@ class TimeFreqDistr:
638
637
  Additional parameter for certain window types (e.g., Gaussian alpha). Default is None.
639
638
  dft_window : bool, optional
640
639
  If True, returns the DFT of the window. Default is False.
641
-
640
+
642
641
  Returns:
643
642
  --------
644
643
  win : ndarray
645
644
  The created window (or its DFT if dft_window is True).
646
645
  """
647
- if win_type == 'delta':
646
+ if win_type == "delta":
648
647
  win = np.zeros(win_length)
649
648
  win[win_length // 2] = 1
650
- elif win_type == 'rect':
649
+ elif win_type == "rect":
651
650
  win = np.ones(win_length)
652
- elif win_type in ['hamm', 'hamming']:
651
+ elif win_type in ["hamm", "hamming"]:
653
652
  win = signal.windows.hamming(win_length)
654
- elif win_type in ['hann', 'hanning']:
653
+ elif win_type in ["hann", "hanning"]:
655
654
  win = signal.windows.hann(win_length)
656
- elif win_type == 'gauss':
657
- win = signal.windows.gaussian(win_length, std=win_param if win_param else 0.4)
658
- elif win_type == 'cosh':
655
+ elif win_type == "gauss":
656
+ win = signal.windows.gaussian(
657
+ win_length, std=win_param if win_param else 0.4
658
+ )
659
+ elif win_type == "cosh":
659
660
  win_hlf = win_length // 2
660
661
  if not win_param:
661
662
  win_param = 0.01
662
- win = np.array([np.cosh(m) ** (-2 * win_param) for m in range(-win_hlf, win_hlf+1)])
663
+ win = np.array(
664
+ [np.cosh(m) ** (-2 * win_param) for m in range(-win_hlf, win_hlf + 1)]
665
+ )
663
666
  win = np.fft.fftshift(win)
664
667
  else:
665
668
  raise ValueError(f"Unknown window type {win_type}")
666
-
669
+
667
670
  # If dft_window is True, return the DFT of the window
668
671
  if dft_window:
669
672
  win = np.fft.fft(np.roll(win, win_length // 2))
670
673
  win = np.roll(win, -win_length // 2)
671
-
674
+
672
675
  return win
673
676
 
674
677
  def shift_window(self, w: np.ndarray) -> np.ndarray:
675
678
  """
676
679
  Shift the window so that positive indices appear first.
677
-
680
+
678
681
  Parameters:
679
682
  -----------
680
683
  w : ndarray
681
684
  Window to be shifted.
682
-
685
+
683
686
  Returns:
684
687
  --------
685
688
  w_shifted : ndarray
@@ -691,19 +694,19 @@ class TimeFreqDistr:
691
694
  def pad_window(self, w: np.ndarray, Npad: int) -> np.ndarray:
692
695
  """
693
696
  Zero-pad the window to a specified length.
694
-
697
+
695
698
  Parameters:
696
699
  -----------
697
700
  w : ndarray
698
701
  The original window.
699
702
  Npad : int
700
703
  Length to zero-pad the window to.
701
-
704
+
702
705
  Returns:
703
706
  --------
704
707
  w_pad : ndarray
705
708
  Zero-padded window of length Npad.
706
-
709
+
707
710
  Raises:
708
711
  -------
709
712
  ValueError:
@@ -712,30 +715,25 @@ class TimeFreqDistr:
712
715
  N = len(w)
713
716
  w_pad = np.zeros(Npad)
714
717
  Nh = N // 2
715
-
718
+
716
719
  if Npad < N:
717
720
  raise ValueError("Npad must be greater than or equal to the window length")
718
721
 
719
722
  if N == Npad:
720
723
  return w
721
-
724
+
722
725
  if N % 2 == 1: # For odd N
723
- w_pad[:Nh+1] = w[:Nh+1]
726
+ w_pad[: Nh + 1] = w[: Nh + 1]
724
727
  w_pad[-Nh:] = w[-Nh:]
725
728
  else: # For even N
726
729
  w_pad[:Nh] = w[:Nh]
727
730
  w_pad[Nh] = w[Nh] / 2
728
731
  w_pad[-Nh:] = w[-Nh:]
729
732
  w_pad[-Nh] = w[Nh] / 2
730
-
733
+
731
734
  return w_pad
732
735
 
733
- def compute_tfd(
734
- self,
735
- N: int,
736
- Nh: int,
737
- tfd: np.ndarray
738
- ):
736
+ def compute_tfd(self, N: int, Nh: int, tfd: np.ndarray):
739
737
  """
740
738
  Finalizes the time-frequency distribution computation.
741
739
 
@@ -756,25 +754,29 @@ class TimeFreqDistr:
756
754
  m = np.arange(0, Nh) # m = 0:(Nh-1)
757
755
  mb = np.arange(1, Nh) # mb = 1:(Nh-1)
758
756
 
759
- for n in range(0, N-1, 2): # for n=0:2:N-2
757
+ for n in range(0, N - 1, 2): # for n=0:2:N-2
760
758
  R_even_half = np.complex128(tfd[n, :Nh]) + 1j * np.complex128(tfd[n, Nh:])
761
- R_odd_half = np.complex128(tfd[n+1, :Nh]) + 1j * np.complex128(tfd[n+1, Nh:])
759
+ R_odd_half = np.complex128(tfd[n + 1, :Nh]) + 1j * np.complex128(
760
+ tfd[n + 1, Nh:]
761
+ )
762
762
 
763
- R_tslice_even = np.zeros(N, dtype=np.complex128)
763
+ R_tslice_even = np.zeros(N, dtype=np.complex128)
764
764
  R_tslice_odd = np.zeros(N, dtype=np.complex128)
765
765
 
766
766
  R_tslice_even[m] = R_even_half
767
767
  R_tslice_odd[m] = R_odd_half
768
768
 
769
- R_tslice_even[N-mb] = np.conj(R_even_half[mb])
770
- R_tslice_odd[N-mb] = np.conj(R_odd_half[mb])
771
-
769
+ R_tslice_even[N - mb] = np.conj(R_even_half[mb])
770
+ R_tslice_odd[N - mb] = np.conj(R_odd_half[mb])
771
+
772
772
  # Perform FFT to compute time slices
773
773
  tfd_time_slice = np.fft.fft(R_tslice_even + 1j * R_tslice_odd)
774
774
 
775
775
  tfd[n, :] = np.real(tfd_time_slice)
776
- tfd[n+1, :] = np.imag(tfd_time_slice)
776
+ tfd[n + 1, :] = np.imag(tfd_time_slice)
777
777
 
778
778
  tfd = tfd / N # Normalize the TFD
779
- tfd = tfd.transpose() # Transpose the TFD to have the time on the x-axis and frequency on the y-axis
779
+ tfd = (
780
+ tfd.transpose()
781
+ ) # Transpose the TFD to have the time on the x-axis and frequency on the y-axis
780
782
  return tfd