foscat 3.7.3__py3-none-any.whl → 3.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/alm.py CHANGED
@@ -97,8 +97,8 @@ class alm:
97
97
  self.ring_th(nside)
98
98
  self.ring_ph(nside)
99
99
  x = (-1j * np.arange(3 * nside)).reshape(1, 3 * nside)
100
- self.matrix_shift_ph[nside] = self.backend.bk_cast(
101
- self.backend.bk_exp(x * self.lph[nside].reshape(4 * nside - 1, 1))
100
+ self.matrix_shift_ph[nside] = self.backend.bk_exp(
101
+ self.backend.bk_cast(x * self.lph[nside].reshape(4 * nside - 1, 1))
102
102
  )
103
103
 
104
104
  self.lmax = 3 * nside - 1
@@ -129,9 +129,9 @@ class alm:
129
129
  - 0.5 * self.log(l + m)
130
130
  )
131
131
 
132
- self.A[nside, m] = self.backend.constant((aval))
133
- self.B[nside, m] = self.backend.constant((bval))
134
- self.ratio_mm[nside, m] = self.backend.constant(
132
+ self.A[nside, m] = self.backend.bk_constant((aval))
133
+ self.B[nside, m] = self.backend.bk_constant((bval))
134
+ self.ratio_mm[nside, m] = self.backend.bk_constant(
135
135
  np.sqrt(4 * np.pi) * np.expand_dims(np.exp(val), 1)
136
136
  )
137
137
  # Calcul de P_{mm}(x)
@@ -141,7 +141,7 @@ class alm:
141
141
  P_mm[m] = 1.0
142
142
  for m in range(3 * nside - 1):
143
143
  P_mm[m] = (0.5 - m % 2) * 2 * (1 - x**2) ** (m / 2)
144
- self.P_mm[nside] = self.backend.constant(P_mm)
144
+ self.P_mm[nside] = self.backend.bk_constant(P_mm)
145
145
 
146
146
  def init_Ys(self, s, nside):
147
147
 
@@ -167,8 +167,8 @@ class alm:
167
167
  vnorm = 1 / np.expand_dims(
168
168
  np.sqrt(2 * (np.arange(ell_max - m + 1) + m) + 1), 1
169
169
  )
170
- self.Yp[s, nside][m] = iplus[idx] * vnorm
171
- self.Ym[s, nside][m] = imoins[idx] * vnorm
170
+ self.Yp[s, nside][m] = self.backend.bk_cast(iplus[idx] * vnorm+0J)
171
+ self.Ym[s, nside][m] = self.backend.bk_cast(imoins[idx] * vnorm+0J)
172
172
 
173
173
  del iplus
174
174
  del imoins
@@ -224,7 +224,7 @@ class alm:
224
224
  result[0] = Pmm
225
225
 
226
226
  if m == lmax:
227
- return result * np.exp(ratio) * np.sqrt(4 * np.pi)
227
+ return result * np.exp(ratio) * np.sqrt(4 * np.pi)+0J
228
228
 
229
229
  # Étape 2 : Calcul de P_{l+1, m}(x)
230
230
  result[1] = x * (2 * m + 1) * result[0]
@@ -245,7 +245,7 @@ class alm:
245
245
  ratio[l - m - 1, 0] += self._log_limit_range
246
246
  ratio[l - m, 0] += self._log_limit_range
247
247
 
248
- return result * np.exp(ratio) * np.sqrt(4 * np.pi)
248
+ return result * np.exp(ratio) * np.sqrt(4 * np.pi)+0J
249
249
 
250
250
  # Calcul des P_{lm}(x) pour tout l inclus dans [m,lmax]
251
251
  def compute_legendre_m_old2(self, x, m, lmax, nside):
@@ -270,13 +270,6 @@ class alm:
270
270
  self.A[nside, m][l - m] * x * result[l - m - 1]
271
271
  - self.B[nside, m][l - m] * result[l - m - 2]
272
272
  )
273
- """
274
- if np.max(abs(result[l-m]))>self._limit_range:
275
- result[l-m-1]*= self._limit_range
276
- result[l-m]*= self._limit_range
277
- ratio[l-m-1]+= self._log_limit_range
278
- ratio[l-m]+= self._log_limit_range
279
- """
280
273
  result = self.backend.bk_reshape(
281
274
  self.backend.bk_concat([result[k] for k in range(lmax + 1 - m)], axis=0),
282
275
  [lmax + 1 - m, 4 * nside - 1],
@@ -425,13 +418,13 @@ class alm:
425
418
  r = self.backend.bk_rfft(val)
426
419
  if axis == 0:
427
420
  r_inv = self.backend.bk_reverse(
428
- self.backend.bk_conjugate(r[1:-1]), axis=axis
421
+ self.backend.bk_conjugate(r[...,1:-1]), axis=-1
429
422
  )
430
423
  else:
431
424
  r_inv = self.backend.bk_reverse(
432
- self.backend.bk_conjugate(r[:, 1:-1]), axis=axis
425
+ self.backend.bk_conjugate(r[..., 1:-1]), axis=-1
433
426
  )
434
- return self.backend.bk_concat([r, r_inv], axis=axis)
427
+ return self.backend.bk_concat([r, r_inv], axis=axis+1)
435
428
 
436
429
  def irfft2fft(self, val, N, axis=0):
437
430
  if axis == 0:
@@ -440,7 +433,9 @@ class alm:
440
433
  return self.backend.bk_irfft(val[:, 0 : N // 2 + 1])
441
434
 
442
435
  def comp_tf(self, im, nside, realfft=False):
443
-
436
+
437
+ #im is [Nimage,12*nside**2]
438
+
444
439
  self.shift_ph(nside)
445
440
  n = 0
446
441
 
@@ -449,65 +444,53 @@ class alm:
449
444
  N = 4 * (k + 1)
450
445
 
451
446
  if realfft:
452
- tmp = self.rfft2fft(im[n : n + N])
447
+ tmp = self.rfft2fft(im[:,n : n + N])
453
448
  else:
454
- tmp = self.backend.bk_fft(im[n : n + N])
449
+ tmp = self.backend.bk_fft(im[:,n : n + N])
455
450
 
456
- l_n = tmp.shape[0]
451
+ l_n = tmp.shape[1]
457
452
 
458
453
  if l_n < 3 * nside + 1:
459
454
  repeat_n = 3 * nside // l_n + 1
460
- tmp = self.backend.bk_tile(tmp, repeat_n, axis=0)
455
+ tmp = self.backend.bk_tile(tmp, repeat_n, axis=1)
461
456
 
462
- ft_im.append(tmp[0 : 3 * nside])
457
+ ft_im.append(tmp[:,None,0 : 3 * nside])
463
458
 
464
459
  n += N
465
- if nside > 1:
466
- result = self.backend.bk_reshape(
467
- self.backend.bk_concat(ft_im, axis=0), [nside - 1, 3 * nside]
468
- )
469
460
 
470
461
  N = 4 * nside * (2 * nside + 1)
471
- v = self.backend.bk_reshape(im[n : n + N], [2 * nside + 1, 4 * nside])
462
+ v = self.backend.bk_reshape(im[:,n : n + N], [im.shape[0],2 * nside + 1, 4 * nside])
472
463
  if realfft:
473
- v_fft = self.rfft2fft(v, axis=1)[:, : 3 * nside]
464
+ v_fft = self.rfft2fft(v, axis=1)[:, :, : 3 * nside]
474
465
  else:
475
- v_fft = self.backend.bk_fft(v)[:, : 3 * nside]
466
+ v_fft = self.backend.bk_fft(v)[:, :, : 3 * nside]
476
467
 
477
468
  n += N
469
+
470
+ ft_im.append(v_fft)
471
+
478
472
  if nside > 1:
479
- result = self.backend.bk_concat([result, v_fft], axis=0)
480
- else:
481
- result = v_fft
482
-
483
- if nside > 1:
484
- ft_im = []
485
473
  for k in range(nside - 1):
486
474
  N = 4 * (nside - 1 - k)
487
475
 
488
476
  if realfft:
489
- tmp = self.rfft2fft(im[n : n + N])[0:l_n]
477
+ tmp = self.rfft2fft(im[:,n : n + N])
490
478
  else:
491
- tmp = self.backend.bk_fft(im[n : n + N])[0:l_n]
479
+ tmp = self.backend.bk_fft(im[:,n : n + N])
492
480
 
493
- l_n = tmp.shape[0]
481
+ l_n = tmp.shape[1]
494
482
 
495
483
  if l_n < 3 * nside + 1:
496
484
  repeat_n = 3 * nside // l_n + 1
497
- tmp = self.backend.bk_tile(tmp, repeat_n, axis=0)
498
-
499
- ft_im.append(tmp[0 : 3 * nside])
485
+ tmp = self.backend.bk_tile(tmp, repeat_n, axis=1)
486
+
487
+ ft_im.append(tmp[:,None,0 : 3 * nside])
500
488
  n += N
501
-
502
- lastresult = self.backend.bk_reshape(
503
- self.backend.bk_concat(ft_im, axis=0), [nside - 1, 3 * nside]
504
- )
505
- return (
506
- self.backend.bk_concat([result, lastresult], axis=0)
507
- * self.matrix_shift_ph[nside]
508
- )
509
- else:
510
- return result * self.matrix_shift_ph[nside]
489
+
490
+ return (
491
+ self.backend.bk_concat(ft_im, axis=1)
492
+ * self.matrix_shift_ph[nside][None,:,:]
493
+ )
511
494
 
512
495
  def icomp_tf(self, i_im, nside, realfft=False):
513
496
 
@@ -565,7 +548,7 @@ class alm:
565
548
  else:
566
549
  return result
567
550
 
568
- def anafast(self, im, map2=None, nest=False, spin=2):
551
+ def anafast(self, im, map2=None, nest=False, spin=2,axes=0):
569
552
  """The `anafast` function computes the L1 and L2 norm power spectra.
570
553
 
571
554
  Currently, it is not optimized for single-pass computation due to the relatively inefficient computation of \(Y_{lm}\).
@@ -584,29 +567,51 @@ class alm:
584
567
  ordered as TT, EE, BB, TE, EB.TBanafast function computes L1 and L2 norm powerspctra.
585
568
 
586
569
  """
570
+ no_input_column = False
571
+
587
572
  i_im = self.backend.bk_cast(im)
588
573
  if map2 is not None:
589
574
  i_map2 = self.backend.bk_cast(map2)
590
575
 
591
576
  doT = True
592
- if len(i_im.shape) == 1: # nopol
593
- nside = int(np.sqrt(i_im.shape[0] // 12))
577
+
578
+ if len(i_im.shape)-axes == 1: # nopol
579
+ nside = int(np.sqrt(i_im.shape[axes] // 12))
594
580
  else:
595
- if i_im.shape[0] == 2:
581
+ if len(i_im.shape)-axes == 2:
596
582
  doT = False
597
- nside = int(np.sqrt(i_im.shape[1] // 12))
598
-
583
+ nside = int(np.sqrt(i_im.shape[axes+1] // 12))
584
+ do_all_pol=False
585
+ if i_im.shape[axes]==3:
586
+ do_all_pol=True
587
+
599
588
  self.shift_ph(nside)
600
589
 
601
- if doT: # nopol
602
- if len(i_im.shape) == 2: # pol
603
- l_im = i_im[0]
604
- if map2 is not None:
605
- l_map2 = i_map2[0]
606
- else:
607
- l_im = i_im
608
- if map2 is not None:
609
- l_map2 = i_map2
590
+ if doT or do_all_pol:
591
+ if len(i_im.shape) == 1 + int(do_all_pol):# no pol if 1 all pol if 2
592
+ if do_all_pol:
593
+ l_im = i_im[None,0,...]
594
+ if map2 is not None:
595
+ l_map2 = i_map2[None,0,...]
596
+ else:
597
+ l_im = i_im[None,...]
598
+ if map2 is not None:
599
+ l_map2 = i_map2[None,...]
600
+ no_input_column = True
601
+ N_image=1
602
+
603
+ else:
604
+ if do_all_pol:
605
+ l_im = i_im[:,0]
606
+ if map2 is not None:
607
+ l_map2 = i_map2[:,0]
608
+ N_image=i_im.shape[0]
609
+
610
+ else:
611
+ l_im = i_im
612
+ if map2 is not None:
613
+ l_map2 = i_map2
614
+ N_image=i_im.shape[0]
610
615
 
611
616
  if nest:
612
617
  idx = hp.ring2nest(nside, np.arange(12 * nside**2))
@@ -622,7 +627,7 @@ class alm:
622
627
  ft_im = self.comp_tf(l_im, nside, realfft=True)
623
628
  if map2 is not None:
624
629
  ft_im2 = self.comp_tf(l_map2, nside, realfft=True)
625
-
630
+
626
631
  lth = self.ring_th(nside)
627
632
 
628
633
  co_th = np.cos(lth)
@@ -634,67 +639,80 @@ class alm:
634
639
  dt2 = 0
635
640
  dt3 = 0
636
641
  dt4 = 0
637
- if len(i_im.shape) == 2: # nopol
642
+ if not doT: # polarize case
638
643
 
639
644
  self.init_Ys(spin, nside)
645
+
646
+ if len(i_im.shape) == 2:
647
+ l_im = i_im[None,:,:]
648
+ if map2 is not None:
649
+ l_map2 = i_map2[None,:,:]
650
+ no_input_column = True
651
+ N_image=1
652
+ else:
653
+ l_im = i_im
654
+ if map2 is not None:
655
+ l_map2 = i_map2
656
+ N_image=i_im.shape[0]
640
657
 
641
658
  if nest:
642
659
  idx = hp.ring2nest(nside, np.arange(12 * nside**2))
643
- l_Q = self.backend.bk_gather(i_im[int(doT)], idx)
644
- l_U = self.backend.bk_gather(i_im[1 + int(doT)], idx)
660
+ l_Q = self.backend.bk_gather(l_im[:,int(do_all_pol)], idx)
661
+ l_U = self.backend.bk_gather(l_im[:,1 + int(do_all_pol)], idx)
645
662
  ft_im_Pp = self.comp_tf(self.backend.bk_complex(l_Q, l_U), nside)
646
663
  ft_im_Pm = self.comp_tf(self.backend.bk_complex(l_Q, -l_U), nside)
647
664
  if map2 is not None:
648
- l_Q = self.backend.bk_gather(i_map2[int(doT)], idx)
649
- l_U = self.backend.bk_gather(i_map2[1 + int(doT)], idx)
665
+ l_Q = self.backend.bk_gather(l_map2[:,int(do_all_pol)], idx)
666
+ l_U = self.backend.bk_gather(l_map2[:,1 + int(do_all_pol)], idx)
650
667
  ft_im2_Pp = self.comp_tf(self.backend.bk_complex(l_Q, l_U), nside)
651
668
  ft_im2_Pm = self.comp_tf(self.backend.bk_complex(l_Q, -l_U), nside)
652
669
  else:
653
670
  ft_im_Pp = self.comp_tf(
654
- self.backend.bk_complex(i_im[int(doT)], i_im[1 + int(doT)]), nside
671
+ self.backend.bk_complex(l_im[:,int(do_all_pol)], l_im[:,1 + int(do_all_pol)]), nside
655
672
  )
656
673
  ft_im_Pm = self.comp_tf(
657
- self.backend.bk_complex(i_im[int(doT)], -i_im[1 + int(doT)]), nside
674
+ self.backend.bk_complex(l_im[:,int(do_all_pol)], -l_im[:,1 + int(do_all_pol)]), nside
658
675
  )
659
676
  if map2 is not None:
660
677
  ft_im2_Pp = self.comp_tf(
661
- self.backend.bk_complex(i_map2[int(doT)], i_map2[1 + int(doT)]),
678
+ self.backend.bk_complex(l_map2[:,int(do_all_pol)], l_map2[:,1 + int(do_all_pol)]),
662
679
  nside,
663
680
  )
664
681
  ft_im2_Pm = self.comp_tf(
665
682
  self.backend.bk_complex(
666
- i_map2[int(doT)], -i_map2[1 + int(doT)]
683
+ l_map2[:,int(doT)], -l_map2[:,1 + int(do_all_pol)]
667
684
  ),
668
685
  nside,
669
686
  )
670
-
687
+
688
+ l_cl=[]
671
689
  for m in range(lmax + 1):
672
690
 
673
- plm = self.compute_legendre_m(co_th, m, 3 * nside - 1, nside) / (
674
- 12 * nside**2
675
- )
691
+ plm = self.backend.bk_cast(
692
+ self.compute_legendre_m(co_th, m, 3 * nside - 1, nside) / (
693
+ 12 * nside**2
694
+ )
695
+ )
676
696
 
677
- if doT:
678
- tmp = self.backend.bk_reduce_sum(plm * ft_im[:, m], 1)
697
+ if doT or do_all_pol:
698
+ tmp = self.backend.bk_reduce_sum(plm[None,:,:] * ft_im[:,None,:, m], 2)
679
699
 
680
700
  if map2 is not None:
681
- tmp2 = self.backend.bk_reduce_sum(plm * ft_im2[:, m], 1)
701
+ tmp2 = self.backend.bk_reduce_sum(plm[None,:,:] * ft_im2[:,None, :, m], 2)
682
702
  else:
683
703
  tmp2 = tmp
684
704
 
685
- if len(i_im.shape) == 2: # pol
705
+ if not doT: # polarize case
686
706
  plmp = self.Yp[spin, nside][m]
687
707
  plmm = self.Ym[spin, nside][m]
688
-
689
- tmpp = self.backend.bk_reduce_sum(plmp * ft_im_Pp[:, m], 1)
690
- tmpm = self.backend.bk_reduce_sum(plmm * ft_im_Pm[:, m], 1)
691
-
708
+ tmpp = self.backend.bk_reduce_sum(plmp[None,:,:] * ft_im_Pp[:,None, :,m], 2)
709
+ tmpm = self.backend.bk_reduce_sum(plmm[None,:,:] * ft_im_Pm[:,None, :,m], 2)
692
710
  almE = -(tmpp + tmpm) / 2.0
693
711
  almB = (tmpp - tmpm) / (2j)
694
712
 
695
713
  if map2 is not None:
696
- tmpp2 = self.backend.bk_reduce_sum(plmp * ft_im2_Pp[:, m], 1)
697
- tmpm2 = self.backend.bk_reduce_sum(plmm * ft_im2_Pm[:, m], 1)
714
+ tmpp2 = self.backend.bk_reduce_sum(plmp[None,:,:] * ft_im2_Pp[:,None,:, m], 2)
715
+ tmpm2 = self.backend.bk_reduce_sum(plmm[None,:,:] * ft_im2_Pm[:,None,:, m], 2)
698
716
 
699
717
  almE2 = -(tmpp2 + tmpm2) / 2.0
700
718
  almB2 = (tmpp2 - tmpm2) / (2j)
@@ -702,7 +720,7 @@ class alm:
702
720
  almE2 = almE
703
721
  almB2 = almB
704
722
 
705
- if doT:
723
+ if do_all_pol:
706
724
  tmpTT = self.backend.bk_real(
707
725
  (tmp * self.backend.bk_conjugate(tmp2))
708
726
  )
@@ -725,7 +743,7 @@ class alm:
725
743
  )
726
744
  ) / 2
727
745
 
728
- if doT:
746
+ if do_all_pol:
729
747
  tmpTE = (
730
748
  tmpTE
731
749
  + self.backend.bk_real(
@@ -740,59 +758,68 @@ class alm:
740
758
  ) / 2
741
759
 
742
760
  if m == 0:
743
- if doT:
744
- l_cl = self.backend.bk_concat(
745
- [tmpTT, tmpEE, tmpBB, tmpTE, tmpEB, tmpTB], 0
746
- )
761
+ if do_all_pol:
762
+ l_cl.append(tmpTT)
763
+ l_cl.append(tmpEE)
764
+ l_cl.append(tmpBB)
765
+ l_cl.append(tmpTE)
766
+ l_cl.append(tmpEB)
767
+ l_cl.append(tmpTB)
747
768
  else:
748
- l_cl = self.backend.bk_concat([tmpEE, tmpBB, tmpEB], 0)
769
+ l_cl.append(tmpEE)
770
+ l_cl.append(tmpBB)
771
+ l_cl.append(tmpEB)
772
+
749
773
  else:
750
774
  offset_tensor = self.backend.bk_zeros(
751
- (m), dtype=self.backend.all_bk_type
775
+ (N_image,m), dtype=self.backend.all_bk_type
752
776
  )
753
- if doT:
754
- l_cl = self.backend.bk_concat(
755
- [
756
- self.backend.bk_concat([offset_tensor, tmpTT], axis=0),
757
- self.backend.bk_concat([offset_tensor, tmpEE], axis=0),
758
- self.backend.bk_concat([offset_tensor, tmpBB], axis=0),
759
- self.backend.bk_concat([offset_tensor, tmpTE], axis=0),
760
- self.backend.bk_concat([offset_tensor, tmpEB], axis=0),
761
- self.backend.bk_concat([offset_tensor, tmpTB], axis=0),
762
- ],
763
- axis=0,
764
- )
777
+ if do_all_pol:
778
+ l_cl.append(offset_tensor)
779
+ l_cl.append(2*tmpTT)
780
+ l_cl.append(offset_tensor)
781
+ l_cl.append(2*tmpEE)
782
+ l_cl.append(offset_tensor)
783
+ l_cl.append(2*tmpBB)
784
+ l_cl.append(offset_tensor)
785
+ l_cl.append(2*tmpTE)
786
+ l_cl.append(offset_tensor)
787
+ l_cl.append(2*tmpEB)
788
+ l_cl.append(offset_tensor)
789
+ l_cl.append(2*tmpTB)
765
790
  else:
766
- l_cl = self.backend.bk_concat(
767
- [
768
- self.backend.bk_concat([offset_tensor, tmpEE], axis=0),
769
- self.backend.bk_concat([offset_tensor, tmpBB], axis=0),
770
- self.backend.bk_concat([offset_tensor, tmpEB], axis=0),
771
- ],
772
- axis=0,
773
- )
774
-
775
- if doT:
776
- l_cl = self.backend.bk_reshape(l_cl, [6, lmax + 1])
777
- else:
778
- l_cl = self.backend.bk_reshape(l_cl, [3, lmax + 1])
791
+ l_cl.append(offset_tensor)
792
+ l_cl.append(2*tmpEE)
793
+ l_cl.append(offset_tensor)
794
+ l_cl.append(2*tmpBB)
795
+ l_cl.append(offset_tensor)
796
+ l_cl.append(2*tmpEB)
779
797
  else:
780
798
  tmp = self.backend.bk_real((tmp * self.backend.bk_conjugate(tmp2)))
781
799
  if m == 0:
782
- l_cl = tmp
800
+ l_cl.append(tmp)
783
801
  else:
784
802
  offset_tensor = self.backend.bk_zeros(
785
- (m), dtype=self.backend.all_bk_type
803
+ (N_image,m), dtype=self.backend.all_bk_type
786
804
  )
787
- l_cl = self.backend.bk_concat([offset_tensor, tmp], axis=0)
788
-
789
- if cl2 is None:
790
- cl2 = l_cl
805
+ l_cl.append(offset_tensor)
806
+ l_cl.append(2*tmp)
807
+
808
+ l_cl=self.backend.bk_concat(l_cl,1)
809
+
810
+ if doT:
811
+ cl2 = self.backend.bk_reshape(l_cl,[N_image,lmax+1,lmax+1])
812
+ cl2 = self.backend.bk_reduce_sum(cl2,1)
813
+ else:
814
+ if do_all_pol:
815
+ cl2 = self.backend.bk_reshape(l_cl,[N_image,lmax+1,6,lmax+1])
791
816
  else:
792
- cl2 += 2 * l_cl
793
-
794
- # cl2=cl2*(4*np.pi) #self.backend.bk_sqrt(self.backend.bk_cast(4*np.pi)) #(2*np.arange(cl2.shape[0])+1)))
795
-
817
+ cl2 = self.backend.bk_reshape(l_cl,[N_image,lmax+1,3,lmax+1])
818
+ cl2 = self.backend.bk_reduce_sum(cl2,1)
819
+
820
+ if no_input_column:
821
+ cl2=cl2[0]
822
+
796
823
  cl2_l1 = self.backend.bk_L1(cl2)
797
824
 
798
825
  return cl2, cl2_l1
foscat/backend.py CHANGED
@@ -927,6 +927,7 @@ class foscat_backend:
927
927
  if self.BACKEND == self.TORCH:
928
928
  if isinstance(data, np.ndarray):
929
929
  return data.reshape(shape)
930
+ return data.view(shape)
930
931
 
931
932
  return self.backend.reshape(data, shape)
932
933
 
@@ -1027,7 +1028,9 @@ class foscat_backend:
1027
1028
 
1028
1029
  def bk_fftn(self, data,dim=None):
1029
1030
  if self.BACKEND == self.TENSORFLOW:
1030
- return self.backend.signal.fftn(data)
1031
+ #Equivalent of torch.fft.fftn(x, dim=dims) in TensorFlow
1032
+ x=self.bk_complex(data,0*data)
1033
+ return self.backend.signal.fftnd(x, fft_length=tuple(x.shape[d] for d in dim),axes=dim)
1031
1034
  if self.BACKEND == self.TORCH:
1032
1035
  return self.backend.fft.fftn(data,dim=dim)
1033
1036
  if self.BACKEND == self.NUMPY:
@@ -1035,7 +1038,7 @@ class foscat_backend:
1035
1038
 
1036
1039
  def bk_ifftn(self, data,dim=None,norm=None):
1037
1040
  if self.BACKEND == self.TENSORFLOW:
1038
- return self.backend.signal.ifftn(data)
1041
+ return self.backend.signal.ifftnd(data,fft_length=tuple(data.shape[d] for d in dim),axes=dim,norm=norm)
1039
1042
  if self.BACKEND == self.TORCH:
1040
1043
  return self.backend.fft.ifftn(data,dim=dim,norm=norm)
1041
1044
  if self.BACKEND == self.NUMPY: