foscat 3.7.3__py3-none-any.whl → 3.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/alm.py CHANGED
@@ -97,8 +97,8 @@ class alm:
97
97
  self.ring_th(nside)
98
98
  self.ring_ph(nside)
99
99
  x = (-1j * np.arange(3 * nside)).reshape(1, 3 * nside)
100
- self.matrix_shift_ph[nside] = self.backend.bk_cast(
101
- self.backend.bk_exp(x * self.lph[nside].reshape(4 * nside - 1, 1))
100
+ self.matrix_shift_ph[nside] = self.backend.bk_exp(
101
+ self.backend.bk_cast(x * self.lph[nside].reshape(4 * nside - 1, 1))
102
102
  )
103
103
 
104
104
  self.lmax = 3 * nside - 1
@@ -129,9 +129,9 @@ class alm:
129
129
  - 0.5 * self.log(l + m)
130
130
  )
131
131
 
132
- self.A[nside, m] = self.backend.constant((aval))
133
- self.B[nside, m] = self.backend.constant((bval))
134
- self.ratio_mm[nside, m] = self.backend.constant(
132
+ self.A[nside, m] = self.backend.bk_constant((aval))
133
+ self.B[nside, m] = self.backend.bk_constant((bval))
134
+ self.ratio_mm[nside, m] = self.backend.bk_constant(
135
135
  np.sqrt(4 * np.pi) * np.expand_dims(np.exp(val), 1)
136
136
  )
137
137
  # Calcul de P_{mm}(x)
@@ -141,7 +141,7 @@ class alm:
141
141
  P_mm[m] = 1.0
142
142
  for m in range(3 * nside - 1):
143
143
  P_mm[m] = (0.5 - m % 2) * 2 * (1 - x**2) ** (m / 2)
144
- self.P_mm[nside] = self.backend.constant(P_mm)
144
+ self.P_mm[nside] = self.backend.bk_constant(P_mm)
145
145
 
146
146
  def init_Ys(self, s, nside):
147
147
 
@@ -167,8 +167,8 @@ class alm:
167
167
  vnorm = 1 / np.expand_dims(
168
168
  np.sqrt(2 * (np.arange(ell_max - m + 1) + m) + 1), 1
169
169
  )
170
- self.Yp[s, nside][m] = iplus[idx] * vnorm
171
- self.Ym[s, nside][m] = imoins[idx] * vnorm
170
+ self.Yp[s, nside][m] = self.backend.bk_cast(iplus[idx] * vnorm+0J)
171
+ self.Ym[s, nside][m] = self.backend.bk_cast(imoins[idx] * vnorm+0J)
172
172
 
173
173
  del iplus
174
174
  del imoins
@@ -224,7 +224,7 @@ class alm:
224
224
  result[0] = Pmm
225
225
 
226
226
  if m == lmax:
227
- return result * np.exp(ratio) * np.sqrt(4 * np.pi)
227
+ return result * np.exp(ratio) * np.sqrt(4 * np.pi)+0J
228
228
 
229
229
  # Étape 2 : Calcul de P_{l+1, m}(x)
230
230
  result[1] = x * (2 * m + 1) * result[0]
@@ -245,7 +245,7 @@ class alm:
245
245
  ratio[l - m - 1, 0] += self._log_limit_range
246
246
  ratio[l - m, 0] += self._log_limit_range
247
247
 
248
- return result * np.exp(ratio) * np.sqrt(4 * np.pi)
248
+ return result * np.exp(ratio) * np.sqrt(4 * np.pi)+0J
249
249
 
250
250
  # Calcul des P_{lm}(x) pour tout l inclus dans [m,lmax]
251
251
  def compute_legendre_m_old2(self, x, m, lmax, nside):
@@ -270,13 +270,6 @@ class alm:
270
270
  self.A[nside, m][l - m] * x * result[l - m - 1]
271
271
  - self.B[nside, m][l - m] * result[l - m - 2]
272
272
  )
273
- """
274
- if np.max(abs(result[l-m]))>self._limit_range:
275
- result[l-m-1]*= self._limit_range
276
- result[l-m]*= self._limit_range
277
- ratio[l-m-1]+= self._log_limit_range
278
- ratio[l-m]+= self._log_limit_range
279
- """
280
273
  result = self.backend.bk_reshape(
281
274
  self.backend.bk_concat([result[k] for k in range(lmax + 1 - m)], axis=0),
282
275
  [lmax + 1 - m, 4 * nside - 1],
@@ -425,13 +418,13 @@ class alm:
425
418
  r = self.backend.bk_rfft(val)
426
419
  if axis == 0:
427
420
  r_inv = self.backend.bk_reverse(
428
- self.backend.bk_conjugate(r[1:-1]), axis=axis
421
+ self.backend.bk_conjugate(r[...,1:-1]), axis=-1
429
422
  )
430
423
  else:
431
424
  r_inv = self.backend.bk_reverse(
432
- self.backend.bk_conjugate(r[:, 1:-1]), axis=axis
425
+ self.backend.bk_conjugate(r[..., 1:-1]), axis=-1
433
426
  )
434
- return self.backend.bk_concat([r, r_inv], axis=axis)
427
+ return self.backend.bk_concat([r, r_inv], axis=axis+1)
435
428
 
436
429
  def irfft2fft(self, val, N, axis=0):
437
430
  if axis == 0:
@@ -440,7 +433,9 @@ class alm:
440
433
  return self.backend.bk_irfft(val[:, 0 : N // 2 + 1])
441
434
 
442
435
  def comp_tf(self, im, nside, realfft=False):
443
-
436
+
437
+ #im is [Nimage,12*nside**2]
438
+
444
439
  self.shift_ph(nside)
445
440
  n = 0
446
441
 
@@ -449,65 +444,53 @@ class alm:
449
444
  N = 4 * (k + 1)
450
445
 
451
446
  if realfft:
452
- tmp = self.rfft2fft(im[n : n + N])
447
+ tmp = self.rfft2fft(im[:,n : n + N])
453
448
  else:
454
- tmp = self.backend.bk_fft(im[n : n + N])
449
+ tmp = self.backend.bk_fft(im[:,n : n + N])
455
450
 
456
- l_n = tmp.shape[0]
451
+ l_n = tmp.shape[1]
457
452
 
458
453
  if l_n < 3 * nside + 1:
459
454
  repeat_n = 3 * nside // l_n + 1
460
- tmp = self.backend.bk_tile(tmp, repeat_n, axis=0)
455
+ tmp = self.backend.bk_tile(tmp, repeat_n, axis=1)
461
456
 
462
- ft_im.append(tmp[0 : 3 * nside])
457
+ ft_im.append(tmp[:,None,0 : 3 * nside])
463
458
 
464
459
  n += N
465
- if nside > 1:
466
- result = self.backend.bk_reshape(
467
- self.backend.bk_concat(ft_im, axis=0), [nside - 1, 3 * nside]
468
- )
469
460
 
470
461
  N = 4 * nside * (2 * nside + 1)
471
- v = self.backend.bk_reshape(im[n : n + N], [2 * nside + 1, 4 * nside])
462
+ v = self.backend.bk_reshape(im[:,n : n + N], [im.shape[0],2 * nside + 1, 4 * nside])
472
463
  if realfft:
473
- v_fft = self.rfft2fft(v, axis=1)[:, : 3 * nside]
464
+ v_fft = self.rfft2fft(v, axis=1)[:, :, : 3 * nside]
474
465
  else:
475
- v_fft = self.backend.bk_fft(v)[:, : 3 * nside]
466
+ v_fft = self.backend.bk_fft(v)[:, :, : 3 * nside]
476
467
 
477
468
  n += N
469
+
470
+ ft_im.append(v_fft)
471
+
478
472
  if nside > 1:
479
- result = self.backend.bk_concat([result, v_fft], axis=0)
480
- else:
481
- result = v_fft
482
-
483
- if nside > 1:
484
- ft_im = []
485
473
  for k in range(nside - 1):
486
474
  N = 4 * (nside - 1 - k)
487
475
 
488
476
  if realfft:
489
- tmp = self.rfft2fft(im[n : n + N])[0:l_n]
477
+ tmp = self.rfft2fft(im[:,n : n + N])
490
478
  else:
491
- tmp = self.backend.bk_fft(im[n : n + N])[0:l_n]
479
+ tmp = self.backend.bk_fft(im[:,n : n + N])
492
480
 
493
- l_n = tmp.shape[0]
481
+ l_n = tmp.shape[1]
494
482
 
495
483
  if l_n < 3 * nside + 1:
496
484
  repeat_n = 3 * nside // l_n + 1
497
- tmp = self.backend.bk_tile(tmp, repeat_n, axis=0)
498
-
499
- ft_im.append(tmp[0 : 3 * nside])
485
+ tmp = self.backend.bk_tile(tmp, repeat_n, axis=1)
486
+
487
+ ft_im.append(tmp[:,None,0 : 3 * nside])
500
488
  n += N
501
-
502
- lastresult = self.backend.bk_reshape(
503
- self.backend.bk_concat(ft_im, axis=0), [nside - 1, 3 * nside]
504
- )
505
- return (
506
- self.backend.bk_concat([result, lastresult], axis=0)
507
- * self.matrix_shift_ph[nside]
508
- )
509
- else:
510
- return result * self.matrix_shift_ph[nside]
489
+
490
+ return (
491
+ self.backend.bk_concat(ft_im, axis=1)
492
+ * self.matrix_shift_ph[nside][None,:,:]
493
+ )
511
494
 
512
495
  def icomp_tf(self, i_im, nside, realfft=False):
513
496
 
@@ -565,7 +548,7 @@ class alm:
565
548
  else:
566
549
  return result
567
550
 
568
- def anafast(self, im, map2=None, nest=False, spin=2):
551
+ def anafast(self, im, map2=None, nest=False, spin=2,axes=0):
569
552
  """The `anafast` function computes the L1 and L2 norm power spectra.
570
553
 
571
554
  Currently, it is not optimized for single-pass computation due to the relatively inefficient computation of \(Y_{lm}\).
@@ -584,45 +567,66 @@ class alm:
584
567
  ordered as TT, EE, BB, TE, EB.TBanafast function computes L1 and L2 norm powerspctra.
585
568
 
586
569
  """
570
+ no_input_column = False
571
+
587
572
  i_im = self.backend.bk_cast(im)
588
573
  if map2 is not None:
589
574
  i_map2 = self.backend.bk_cast(map2)
590
575
 
591
576
  doT = True
592
- if len(i_im.shape) == 1: # nopol
593
- nside = int(np.sqrt(i_im.shape[0] // 12))
577
+
578
+ if len(i_im.shape)-axes == 1: # nopol
579
+ nside = int(np.sqrt(i_im.shape[axes] // 12))
594
580
  else:
595
- if i_im.shape[0] == 2:
581
+ if len(i_im.shape)-axes == 2:
596
582
  doT = False
597
- nside = int(np.sqrt(i_im.shape[1] // 12))
598
-
583
+ nside = int(np.sqrt(i_im.shape[axes+1] // 12))
584
+ do_all_pol=False
585
+ if i_im.shape[axes]==3:
586
+ do_all_pol=True
587
+
599
588
  self.shift_ph(nside)
600
589
 
601
- if doT: # nopol
602
- if len(i_im.shape) == 2: # pol
603
- l_im = i_im[0]
604
- if map2 is not None:
605
- l_map2 = i_map2[0]
606
- else:
607
- l_im = i_im
608
- if map2 is not None:
609
- l_map2 = i_map2
590
+ if doT or do_all_pol:
591
+ if len(i_im.shape) == 1 + int(do_all_pol):# no pol if 1 all pol if 2
592
+ if do_all_pol:
593
+ l_im = i_im[None,0,...]
594
+ if map2 is not None:
595
+ l_map2 = i_map2[None,0,...]
596
+ else:
597
+ l_im = i_im[None,...]
598
+ if map2 is not None:
599
+ l_map2 = i_map2[None,...]
600
+ no_input_column = True
601
+ N_image=1
602
+
603
+ else:
604
+ if do_all_pol:
605
+ l_im = i_im[:,0]
606
+ if map2 is not None:
607
+ l_map2 = i_map2[:,0]
608
+ N_image=i_im.shape[0]
609
+
610
+ else:
611
+ l_im = i_im
612
+ if map2 is not None:
613
+ l_map2 = i_map2
614
+ N_image=i_im.shape[0]
610
615
 
611
616
  if nest:
612
617
  idx = hp.ring2nest(nside, np.arange(12 * nside**2))
613
- if len(i_im.shape) == 1: # nopol
614
- ft_im = self.comp_tf(
615
- self.backend.bk_gather(l_im, idx), nside, realfft=True
618
+ ft_im = self.comp_tf(
619
+ self.backend.bk_gather(l_im, idx, axis=1), nside, realfft=True
620
+ )
621
+ if map2 is not None:
622
+ ft_im2 = self.comp_tf(
623
+ self.backend.bk_gather(l_map2, idx, axis=1), nside, realfft=True
616
624
  )
617
- if map2 is not None:
618
- ft_im2 = self.comp_tf(
619
- self.backend.bk_gather(l_map2, idx), nside, realfft=True
620
- )
621
625
  else:
622
626
  ft_im = self.comp_tf(l_im, nside, realfft=True)
623
627
  if map2 is not None:
624
628
  ft_im2 = self.comp_tf(l_map2, nside, realfft=True)
625
-
629
+
626
630
  lth = self.ring_th(nside)
627
631
 
628
632
  co_th = np.cos(lth)
@@ -634,67 +638,80 @@ class alm:
634
638
  dt2 = 0
635
639
  dt3 = 0
636
640
  dt4 = 0
637
- if len(i_im.shape) == 2: # nopol
641
+ if not doT: # polarize case
638
642
 
639
643
  self.init_Ys(spin, nside)
644
+
645
+ if len(i_im.shape) == 2:
646
+ l_im = i_im[None,:,:]
647
+ if map2 is not None:
648
+ l_map2 = i_map2[None,:,:]
649
+ no_input_column = True
650
+ N_image=1
651
+ else:
652
+ l_im = i_im
653
+ if map2 is not None:
654
+ l_map2 = i_map2
655
+ N_image=i_im.shape[0]
640
656
 
641
657
  if nest:
642
658
  idx = hp.ring2nest(nside, np.arange(12 * nside**2))
643
- l_Q = self.backend.bk_gather(i_im[int(doT)], idx)
644
- l_U = self.backend.bk_gather(i_im[1 + int(doT)], idx)
659
+ l_Q = self.backend.bk_gather(l_im[:,int(do_all_pol)], idx,axis=1)
660
+ l_U = self.backend.bk_gather(l_im[:,1 + int(do_all_pol)], idx,axis=1)
645
661
  ft_im_Pp = self.comp_tf(self.backend.bk_complex(l_Q, l_U), nside)
646
662
  ft_im_Pm = self.comp_tf(self.backend.bk_complex(l_Q, -l_U), nside)
647
663
  if map2 is not None:
648
- l_Q = self.backend.bk_gather(i_map2[int(doT)], idx)
649
- l_U = self.backend.bk_gather(i_map2[1 + int(doT)], idx)
664
+ l_Q = self.backend.bk_gather(l_map2[:,int(do_all_pol)], idx,axis=1)
665
+ l_U = self.backend.bk_gather(l_map2[:,1 + int(do_all_pol)], idx,axis=1)
650
666
  ft_im2_Pp = self.comp_tf(self.backend.bk_complex(l_Q, l_U), nside)
651
667
  ft_im2_Pm = self.comp_tf(self.backend.bk_complex(l_Q, -l_U), nside)
652
668
  else:
653
669
  ft_im_Pp = self.comp_tf(
654
- self.backend.bk_complex(i_im[int(doT)], i_im[1 + int(doT)]), nside
670
+ self.backend.bk_complex(l_im[:,int(do_all_pol)], l_im[:,1 + int(do_all_pol)]), nside
655
671
  )
656
672
  ft_im_Pm = self.comp_tf(
657
- self.backend.bk_complex(i_im[int(doT)], -i_im[1 + int(doT)]), nside
673
+ self.backend.bk_complex(l_im[:,int(do_all_pol)], -l_im[:,1 + int(do_all_pol)]), nside
658
674
  )
659
675
  if map2 is not None:
660
676
  ft_im2_Pp = self.comp_tf(
661
- self.backend.bk_complex(i_map2[int(doT)], i_map2[1 + int(doT)]),
677
+ self.backend.bk_complex(l_map2[:,int(do_all_pol)], l_map2[:,1 + int(do_all_pol)]),
662
678
  nside,
663
679
  )
664
680
  ft_im2_Pm = self.comp_tf(
665
681
  self.backend.bk_complex(
666
- i_map2[int(doT)], -i_map2[1 + int(doT)]
682
+ l_map2[:,int(doT)], -l_map2[:,1 + int(do_all_pol)]
667
683
  ),
668
684
  nside,
669
685
  )
670
-
686
+
687
+ l_cl=[]
671
688
  for m in range(lmax + 1):
672
689
 
673
- plm = self.compute_legendre_m(co_th, m, 3 * nside - 1, nside) / (
674
- 12 * nside**2
675
- )
690
+ plm = self.backend.bk_cast(
691
+ self.compute_legendre_m(co_th, m, 3 * nside - 1, nside) / (
692
+ 12 * nside**2
693
+ )
694
+ )
676
695
 
677
- if doT:
678
- tmp = self.backend.bk_reduce_sum(plm * ft_im[:, m], 1)
696
+ if doT or do_all_pol:
697
+ tmp = self.backend.bk_reduce_sum(plm[None,:,:] * ft_im[:,None,:, m], 2)
679
698
 
680
699
  if map2 is not None:
681
- tmp2 = self.backend.bk_reduce_sum(plm * ft_im2[:, m], 1)
700
+ tmp2 = self.backend.bk_reduce_sum(plm[None,:,:] * ft_im2[:,None, :, m], 2)
682
701
  else:
683
702
  tmp2 = tmp
684
703
 
685
- if len(i_im.shape) == 2: # pol
704
+ if not doT: # polarize case
686
705
  plmp = self.Yp[spin, nside][m]
687
706
  plmm = self.Ym[spin, nside][m]
688
-
689
- tmpp = self.backend.bk_reduce_sum(plmp * ft_im_Pp[:, m], 1)
690
- tmpm = self.backend.bk_reduce_sum(plmm * ft_im_Pm[:, m], 1)
691
-
707
+ tmpp = self.backend.bk_reduce_sum(plmp[None,:,:] * ft_im_Pp[:,None, :,m], 2)
708
+ tmpm = self.backend.bk_reduce_sum(plmm[None,:,:] * ft_im_Pm[:,None, :,m], 2)
692
709
  almE = -(tmpp + tmpm) / 2.0
693
710
  almB = (tmpp - tmpm) / (2j)
694
711
 
695
712
  if map2 is not None:
696
- tmpp2 = self.backend.bk_reduce_sum(plmp * ft_im2_Pp[:, m], 1)
697
- tmpm2 = self.backend.bk_reduce_sum(plmm * ft_im2_Pm[:, m], 1)
713
+ tmpp2 = self.backend.bk_reduce_sum(plmp[None,:,:] * ft_im2_Pp[:,None,:, m], 2)
714
+ tmpm2 = self.backend.bk_reduce_sum(plmm[None,:,:] * ft_im2_Pm[:,None,:, m], 2)
698
715
 
699
716
  almE2 = -(tmpp2 + tmpm2) / 2.0
700
717
  almB2 = (tmpp2 - tmpm2) / (2j)
@@ -702,7 +719,7 @@ class alm:
702
719
  almE2 = almE
703
720
  almB2 = almB
704
721
 
705
- if doT:
722
+ if do_all_pol:
706
723
  tmpTT = self.backend.bk_real(
707
724
  (tmp * self.backend.bk_conjugate(tmp2))
708
725
  )
@@ -725,7 +742,7 @@ class alm:
725
742
  )
726
743
  ) / 2
727
744
 
728
- if doT:
745
+ if do_all_pol:
729
746
  tmpTE = (
730
747
  tmpTE
731
748
  + self.backend.bk_real(
@@ -740,59 +757,68 @@ class alm:
740
757
  ) / 2
741
758
 
742
759
  if m == 0:
743
- if doT:
744
- l_cl = self.backend.bk_concat(
745
- [tmpTT, tmpEE, tmpBB, tmpTE, tmpEB, tmpTB], 0
746
- )
760
+ if do_all_pol:
761
+ l_cl.append(tmpTT)
762
+ l_cl.append(tmpEE)
763
+ l_cl.append(tmpBB)
764
+ l_cl.append(tmpTE)
765
+ l_cl.append(tmpEB)
766
+ l_cl.append(tmpTB)
747
767
  else:
748
- l_cl = self.backend.bk_concat([tmpEE, tmpBB, tmpEB], 0)
768
+ l_cl.append(tmpEE)
769
+ l_cl.append(tmpBB)
770
+ l_cl.append(tmpEB)
771
+
749
772
  else:
750
773
  offset_tensor = self.backend.bk_zeros(
751
- (m), dtype=self.backend.all_bk_type
774
+ (N_image,m), dtype=self.backend.all_bk_type
752
775
  )
753
- if doT:
754
- l_cl = self.backend.bk_concat(
755
- [
756
- self.backend.bk_concat([offset_tensor, tmpTT], axis=0),
757
- self.backend.bk_concat([offset_tensor, tmpEE], axis=0),
758
- self.backend.bk_concat([offset_tensor, tmpBB], axis=0),
759
- self.backend.bk_concat([offset_tensor, tmpTE], axis=0),
760
- self.backend.bk_concat([offset_tensor, tmpEB], axis=0),
761
- self.backend.bk_concat([offset_tensor, tmpTB], axis=0),
762
- ],
763
- axis=0,
764
- )
776
+ if do_all_pol:
777
+ l_cl.append(offset_tensor)
778
+ l_cl.append(2*tmpTT)
779
+ l_cl.append(offset_tensor)
780
+ l_cl.append(2*tmpEE)
781
+ l_cl.append(offset_tensor)
782
+ l_cl.append(2*tmpBB)
783
+ l_cl.append(offset_tensor)
784
+ l_cl.append(2*tmpTE)
785
+ l_cl.append(offset_tensor)
786
+ l_cl.append(2*tmpEB)
787
+ l_cl.append(offset_tensor)
788
+ l_cl.append(2*tmpTB)
765
789
  else:
766
- l_cl = self.backend.bk_concat(
767
- [
768
- self.backend.bk_concat([offset_tensor, tmpEE], axis=0),
769
- self.backend.bk_concat([offset_tensor, tmpBB], axis=0),
770
- self.backend.bk_concat([offset_tensor, tmpEB], axis=0),
771
- ],
772
- axis=0,
773
- )
774
-
775
- if doT:
776
- l_cl = self.backend.bk_reshape(l_cl, [6, lmax + 1])
777
- else:
778
- l_cl = self.backend.bk_reshape(l_cl, [3, lmax + 1])
790
+ l_cl.append(offset_tensor)
791
+ l_cl.append(2*tmpEE)
792
+ l_cl.append(offset_tensor)
793
+ l_cl.append(2*tmpBB)
794
+ l_cl.append(offset_tensor)
795
+ l_cl.append(2*tmpEB)
779
796
  else:
780
797
  tmp = self.backend.bk_real((tmp * self.backend.bk_conjugate(tmp2)))
781
798
  if m == 0:
782
- l_cl = tmp
799
+ l_cl.append(tmp)
783
800
  else:
784
801
  offset_tensor = self.backend.bk_zeros(
785
- (m), dtype=self.backend.all_bk_type
802
+ (N_image,m), dtype=self.backend.all_bk_type
786
803
  )
787
- l_cl = self.backend.bk_concat([offset_tensor, tmp], axis=0)
788
-
789
- if cl2 is None:
790
- cl2 = l_cl
804
+ l_cl.append(offset_tensor)
805
+ l_cl.append(2*tmp)
806
+
807
+ l_cl=self.backend.bk_concat(l_cl,1)
808
+
809
+ if doT:
810
+ cl2 = self.backend.bk_reshape(l_cl,[N_image,lmax+1,lmax+1])
811
+ cl2 = self.backend.bk_reduce_sum(cl2,1)
812
+ else:
813
+ if do_all_pol:
814
+ cl2 = self.backend.bk_reshape(l_cl,[N_image,lmax+1,6,lmax+1])
791
815
  else:
792
- cl2 += 2 * l_cl
793
-
794
- # cl2=cl2*(4*np.pi) #self.backend.bk_sqrt(self.backend.bk_cast(4*np.pi)) #(2*np.arange(cl2.shape[0])+1)))
795
-
816
+ cl2 = self.backend.bk_reshape(l_cl,[N_image,lmax+1,3,lmax+1])
817
+ cl2 = self.backend.bk_reduce_sum(cl2,1)
818
+
819
+ if no_input_column:
820
+ cl2=cl2[0]
821
+
796
822
  cl2_l1 = self.backend.bk_L1(cl2)
797
823
 
798
824
  return cl2, cl2_l1
foscat/backend.py CHANGED
@@ -927,6 +927,7 @@ class foscat_backend:
927
927
  if self.BACKEND == self.TORCH:
928
928
  if isinstance(data, np.ndarray):
929
929
  return data.reshape(shape)
930
+ return data.view(shape)
930
931
 
931
932
  return self.backend.reshape(data, shape)
932
933
 
@@ -1027,7 +1028,9 @@ class foscat_backend:
1027
1028
 
1028
1029
  def bk_fftn(self, data,dim=None):
1029
1030
  if self.BACKEND == self.TENSORFLOW:
1030
- return self.backend.signal.fftn(data)
1031
+ #Equivalent of torch.fft.fftn(x, dim=dims) in TensorFlow
1032
+ x=self.bk_complex(data,0*data)
1033
+ return self.backend.signal.fftnd(x, fft_length=tuple(x.shape[d] for d in dim),axes=dim)
1031
1034
  if self.BACKEND == self.TORCH:
1032
1035
  return self.backend.fft.fftn(data,dim=dim)
1033
1036
  if self.BACKEND == self.NUMPY:
@@ -1035,7 +1038,7 @@ class foscat_backend:
1035
1038
 
1036
1039
  def bk_ifftn(self, data,dim=None,norm=None):
1037
1040
  if self.BACKEND == self.TENSORFLOW:
1038
- return self.backend.signal.ifftn(data)
1041
+ return self.backend.signal.ifftnd(data,fft_length=tuple(data.shape[d] for d in dim),axes=dim,norm=norm)
1039
1042
  if self.BACKEND == self.TORCH:
1040
1043
  return self.backend.fft.ifftn(data,dim=dim,norm=norm)
1041
1044
  if self.BACKEND == self.NUMPY: