foscat 3.0.46__py3-none-any.whl → 3.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -859,6 +859,30 @@ class scat_cov:
859
859
  tab2nx=tab2nx+['%d'%(i2)]
860
860
  ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
861
861
  n=n+j2[j1==i2].shape[0]-1
862
+ elif len(tmp.shape)==3:
863
+ for i0 in range(tmp.shape[0]):
864
+ for i1 in range(tmp.shape[1]):
865
+ for i2 in range(j1.max()+1):
866
+ dtmp=tmp[i0,i1,j1==i2]
867
+ if norm:
868
+ dtmp=dtmp/(ntmp[i0,i1,i2]*ntmp[i0,i1,j2[j1==i2]])
869
+ if j2[j1==i2].shape[0]==1:
870
+ ax1.plot(j2[j1==i2]+n,dtmp,'.', \
871
+ color=color, lw=lw)
872
+ else:
873
+ if legend and test is None:
874
+ ax1.plot(j2[j1==i2]+n,dtmp, \
875
+ color=color, label=lname, lw=lw)
876
+ test=1
877
+ ax1.plot(j2[j1==i2]+n,dtmp, \
878
+ color=color, lw=lw)
879
+ tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
880
+ tabx=tabx+[k+n for k in j2[j1==i2]]
881
+ tab2x=tab2x+[(j2[j1==i2]+n).mean()]
882
+ tab2nx=tab2nx+['%d'%(i2)]
883
+ ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
884
+ n=n+j2[j1==i2].shape[0]-1
885
+
862
886
  else:
863
887
  for i0 in range(tmp.shape[0]):
864
888
  for i1 in range(tmp.shape[1]):
@@ -951,6 +975,32 @@ class scat_cov:
951
975
  tab2x=tab2x+[(n+nprev-1)/2]
952
976
  tab2nx=tab2nx+['%d'%(i2)]
953
977
  ax1.axvline(n-0.5,ls=':',color='gray')
978
+ elif len(tmp.shape)==3:
979
+ for i0 in range(tmp.shape[0]):
980
+ for i1 in range(tmp.shape[1]):
981
+ for i2 in range(j1.max()+1):
982
+ nprev=n
983
+ for i2b in range(j2[j1==i2].max()+1):
984
+ idx=np.where((j1==i2)*(j2==i2b))[0]
985
+ dtmp=tmp[i0,i1,idx]
986
+ if norm:
987
+ dtmp=dtmp/(ntmp[i0,i1,i2]*ntmp[i0,i1,i2b])
988
+ if len(idx)==1:
989
+ ax1.plot(np.arange(len(idx))+n,dtmp,'.', \
990
+ color=color, lw=lw)
991
+ else:
992
+ if legend and test is None:
993
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
994
+ color=color, label=lname, lw=lw)
995
+ test=1
996
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
997
+ color=color, lw=lw)
998
+ tabnx=tabnx+[r'%d,%d'%(j2[k],j3[k]) for k in idx]
999
+ tabx=tabx+[k+n for k in range(len(idx))]
1000
+ n=n+idx.shape[0]
1001
+ tab2x=tab2x+[(n+nprev-1)/2]
1002
+ tab2nx=tab2nx+['%d'%(i2)]
1003
+ ax1.axvline(n-0.5,ls=':',color='gray')
954
1004
  else:
955
1005
  for i0 in range(tmp.shape[0]):
956
1006
  for i1 in range(tmp.shape[1]):
@@ -1495,6 +1545,8 @@ class funct(FOC.FoCUS):
1495
1545
  def fill(self,im,nullval=hp.UNSEEN):
1496
1546
  if self.use_2D:
1497
1547
  return self.fill_2d(im,nullval=nullval)
1548
+ if self.use_1D:
1549
+ return self.fill_1d(im,nullval=nullval)
1498
1550
  return self.fill_healpy(im,nullval=nullval)
1499
1551
 
1500
1552
  def moments(self,list_scat):
@@ -1748,6 +1800,15 @@ class funct(FOC.FoCUS):
1748
1800
  x1=im_shape[1]
1749
1801
  x2=im_shape[2]
1750
1802
  J = int(np.log(nside-self.KERNELSZ) / np.log(2)) # Number of j scales
1803
+ elif self.use_1D:
1804
+ if len(image1.shape)==2:
1805
+ npix = int(im_shape[1]) # Number of pixels
1806
+ else:
1807
+ npix = int(im_shape[0]) # Number of pixels
1808
+
1809
+ nside=int(npix)
1810
+
1811
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
1751
1812
  else:
1752
1813
  if len(image1.shape)==2:
1753
1814
  npix = int(im_shape[1]) # Number of pixels
@@ -1785,6 +1846,11 @@ class funct(FOC.FoCUS):
1785
1846
  I1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1786
1847
  if cross:
1787
1848
  I2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1849
+ elif self.use_1D:
1850
+ vmask=self.up_grade(vmask,I1.shape[axis]*2,axis=1)
1851
+ I1=self.up_grade(I1,I1.shape[axis]*2,axis=axis)
1852
+ if cross:
1853
+ I2=self.up_grade(I2,I2.shape[axis]*2,axis=axis)
1788
1854
  else:
1789
1855
  I1 = self.up_grade(I1, nside * 2, axis=axis)
1790
1856
  vmask = self.up_grade(vmask, nside * 2, axis=1)
@@ -1798,6 +1864,11 @@ class funct(FOC.FoCUS):
1798
1864
  I1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1799
1865
  if cross:
1800
1866
  I2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1867
+ elif self.use_1D:
1868
+ vmask=self.up_grade(vmask,I1.shape[axis]*4,axis=1)
1869
+ I1=self.up_grade(I1,I1.shape[axis]*4,axis=axis)
1870
+ if cross:
1871
+ I2=self.up_grade(I2,I2.shape[axis]*4,axis=axis)
1801
1872
  else:
1802
1873
  I1 = self.up_grade(I1, nside * 4, axis=axis)
1803
1874
  vmask = self.up_grade(vmask, nside * 4, axis=1)
@@ -1811,6 +1882,14 @@ class funct(FOC.FoCUS):
1811
1882
  # Coefficients
1812
1883
  S1, P00, C01, C11, C10 = None, None, None, None, None
1813
1884
 
1885
+ off_P0=-2
1886
+ off_C01=-3
1887
+ off_C11=-4
1888
+ if self.use_1D:
1889
+ off_P0=-1
1890
+ off_C01=-1
1891
+ off_C11=-1
1892
+
1814
1893
  # Dictionaries for C01 computation
1815
1894
  M1_dic = {} # M stands for Module M1 = |I1 * Psi|
1816
1895
  if cross:
@@ -1842,7 +1921,6 @@ class funct(FOC.FoCUS):
1842
1921
  s0 = self.masked_mean(I1,vmask,axis=1)
1843
1922
  else:
1844
1923
  s0 = self.masked_mean(I1-I2,vmask,axis=1)
1845
-
1846
1924
 
1847
1925
  #### COMPUTE S1, P00, C01 and C11
1848
1926
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
@@ -1887,6 +1965,7 @@ class funct(FOC.FoCUS):
1887
1965
  else:
1888
1966
  p00 = self.masked_mean(M1_square, vmask, axis=1,rank=j3)
1889
1967
 
1968
+
1890
1969
  if cond_init_P1_dic:
1891
1970
  # We fill P1_dic with P00 for normalisation of C01 and C11
1892
1971
  P1_dic[j3] = p00 # [Nbatch, Nmask, Norient3]
@@ -1900,13 +1979,13 @@ class funct(FOC.FoCUS):
1900
1979
  if norm == 'auto': # Normalize P00
1901
1980
  p00 /= P1_dic[j3]
1902
1981
  if P00 is None:
1903
- P00 = p00[:, :, None, :] # Add a dimension for NP00
1982
+ P00 = self.backend.bk_expand_dims(p00,off_P0) # Add a dimension for NP00
1904
1983
  if calc_var:
1905
- VP00 = vp00[:, :, None, :] # Add a dimension for NP00
1984
+ VP00 = self.backend.bk_expand_dims(vp00,off_P0) # Add a dimension for NP00
1906
1985
  else:
1907
- P00 = self.backend.bk_concat([P00, p00[:, :, None, :]], axis=2)
1986
+ P00 = self.backend.bk_concat([P00, self.backend.bk_expand_dims(p00,off_P0)], axis=2)
1908
1987
  if calc_var:
1909
- VP00 = self.backend.bk_concat([VP00, vp00[:, :, None, :]], axis=2)
1988
+ VP00 = self.backend.bk_concat([VP00, self.backend.bk_expand_dims(vp00,off_P0)], axis=2)
1910
1989
 
1911
1990
  #### S1_auto computation
1912
1991
  ### Image 1 : S1 = < M1 >_pix
@@ -1929,13 +2008,13 @@ class funct(FOC.FoCUS):
1929
2008
  self.div_norm(s1,(P1_dic[j3]) ** 0.5)
1930
2009
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
1931
2010
  if S1 is None:
1932
- S1 = s1[:, :, None, :] # Add a dimension for NS1
2011
+ S1 = self.backend.bk_expand_dims(s1,off_P0) # Add a dimension for NS1
1933
2012
  if calc_var:
1934
- VS1 = vs1[:, :, None, :] # Add a dimension for NS1
2013
+ VS1 = self.backend.bk_expand_dims(vs1,off_P0) # Add a dimension for NS1
1935
2014
  else:
1936
- S1 = self.backend.bk_concat([S1, s1[:, :, None, :]], axis=2)
2015
+ S1 = self.backend.bk_concat([S1,self.backend.bk_expand_dims(s1,off_P0)], axis=2)
1937
2016
  if calc_var:
1938
- VS1 = self.backend.bk_concat([VS1, vs1[:, :, None, :]], axis=2)
2017
+ VS1 = self.backend.bk_concat([VS1, self.backend.bk_expand_dims(vs1,off_P0)], axis=2)
1939
2018
 
1940
2019
  else: # Cross
1941
2020
  ### Make the convolution I2 * Psi_j3
@@ -1994,13 +2073,13 @@ class funct(FOC.FoCUS):
1994
2073
  p00=self.backend.bk_real(p00)
1995
2074
 
1996
2075
  if P00 is None:
1997
- P00 = p00[:,:,None,:] # Add a dimension for NP00
2076
+ P00 = self.backend.bk_expand_dims(p00,off_P0) # Add a dimension for NP00
1998
2077
  if calc_var:
1999
- VP00 = vp00[:,:,None,:] # Add a dimension for NP00
2078
+ VP00 = self.backend.bk_expand_dims(vp00,off_P0) # Add a dimension for NP00
2000
2079
  else:
2001
- P00 = self.backend.bk_concat([P00, p00[:,:,None,:]], axis=2)
2080
+ P00 = self.backend.bk_concat([P00, self.backend.bk_expand_dims(p00,off_P0)], axis=2)
2002
2081
  if calc_var:
2003
- VP00 = self.backend.bk_concat([VP00, vp00[:,:,None,:]], axis=2)
2082
+ VP00 = self.backend.bk_concat([VP00, self.backend.bk_expand_dims(vp00,off_P0)], axis=2)
2004
2083
 
2005
2084
  #### S1_auto computation
2006
2085
  ### Image 1 : S1 = < M1 >_pix
@@ -2022,14 +2101,15 @@ class funct(FOC.FoCUS):
2022
2101
  self.div_norm(s1,(P1_dic[j3]) ** 0.5)
2023
2102
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2024
2103
  if S1 is None:
2025
- S1 = s1[:, :, None, :] # Add a dimension for NS1
2104
+ S1 = self.backend.bk_expand_dims(s1,off_P0) # Add a dimension for NS1
2026
2105
  if calc_var:
2027
- VS1 = vs1[:, :, None, :] # Add a dimension for NS1
2106
+ VS1 = self.backend.bk_expand_dims(vs1,off_P0) # Add a dimension for NS1
2028
2107
  else:
2029
- S1 = self.backend.bk_concat([S1, s1[:, :, None, :]], axis=2)
2108
+ S1 = self.backend.bk_concat([S1, self.backend.bk_expand_dims(s1,off_P0)], axis=2)
2030
2109
  if calc_var:
2031
- VS1 = self.backend.bk_concat([VS1, vs1[:, :, None, :]], axis=2)
2032
-
2110
+ VS1 = self.backend.bk_concat([VS1,
2111
+ self.backend.bk_expand_dims(vs1,off_P0)], axis=2)
2112
+
2033
2113
  # Initialize dictionaries for |I1*Psi_j| * Psi_j3
2034
2114
  M1convPsi_dic = {}
2035
2115
  if cross:
@@ -2067,19 +2147,19 @@ class funct(FOC.FoCUS):
2067
2147
  else:
2068
2148
  ### Normalize C01 with P00_j [Nbatch, Nmask, Norient_j]
2069
2149
  if norm is not None:
2070
- self.div_norm(c01,(P1_dic[j2][:, :, None, :] *
2071
- P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2150
+ self.div_norm(c01,(self.backend.bk_expand_dims(P1_dic[j2],off_P0) *
2151
+ self.backend.bk_expand_dims(P1_dic[j3],-1)) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2072
2152
 
2073
2153
  ### Store C01 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
2074
2154
  if C01 is None:
2075
- C01 = c01[:,:,None,:,:] # Add a dimension for NC01
2155
+ C01 = self.backend.bk_expand_dims(c01,off_C01) # Add a dimension for NC01
2076
2156
  if calc_var:
2077
- VC01 = vc01[:,:,None,:,:] # Add a dimension for NC01
2157
+ VC01 =self.backend.bk_expand_dims(vc01,off_C01) # Add a dimension for NC01
2078
2158
  else:
2079
- C01 = self.backend.bk_concat([C01, c01[:, :, None, :, :]],
2159
+ C01 = self.backend.bk_concat([C01, self.backend.bk_expand_dims(c01,off_C01)],
2080
2160
  axis=2) # Add a dimension for NC01
2081
2161
  if calc_var:
2082
- VC01 = self.backend.bk_concat([VC01, vc01[:, :, None, :, :]],
2162
+ VC01 = self.backend.bk_concat([VC01, self.backend.bk_expand_dims(vc01,off_C01)],
2083
2163
  axis=2) # Add a dimension for NC01
2084
2164
 
2085
2165
  ### C01_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
@@ -2121,28 +2201,28 @@ class funct(FOC.FoCUS):
2121
2201
  else:
2122
2202
  ### Normalize C01 and C10 with P00_j [Nbatch, Nmask, Norient_j]
2123
2203
  if norm is not None:
2124
- self.div_norm(c01,(P2_dic[j2][:, :, None, :] *
2125
- P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2126
- self.div_norm(c10,(P1_dic[j2][:, :, None, :] *
2127
- P2_dic[j3][:, :, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2]
2204
+ self.div_norm(c01,(self.backend.bk_expand_dims(P2_dic[j2],off_P0) *
2205
+ self.backend.bk_expand_dims(P1_dic[j3],-1)) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2206
+ self.div_norm(c10,(self.backend.bk_expand_dims(P1_dic[j2],off_P0) *
2207
+ self.backend.bk_expand_dims(P2_dic[j3],-1)) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2]
2128
2208
 
2129
2209
  ### Store C01 and C10 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
2130
2210
  if C01 is None:
2131
- C01 = c01[:, :, None, :, :] # Add a dimension for NC01
2211
+ C01 = self.backend.bk_expand_dims(c01,off_C01) # Add a dimension for NC01
2132
2212
  if calc_var:
2133
- VC01 = vc01[:, :, None, :, :] # Add a dimension for NC01
2213
+ VC01 = vself.backend.bk_expand_dims(vc01,off_C01) # Add a dimension for NC01
2134
2214
  else:
2135
- C01 = self.backend.bk_concat([C01,c01[:, :, None, :, :]],axis=2) # Add a dimension for NC01
2215
+ C01 = self.backend.bk_concat([C01, self.backend.bk_expand_dims(c01,off_C01)],axis=2) # Add a dimension for NC01
2136
2216
  if calc_var:
2137
- VC01 = self.backend.bk_concat([VC01,vc01[:, :, None, :, :]],axis=2) # Add a dimension for NC01
2217
+ VC01 =self.backend.bk_concat([VC01, self.backend.bk_expand_dims(vc01,off_C01)],axis=2) # Add a dimension for NC01
2138
2218
  if C10 is None:
2139
- C10 = c10[:, :, None, :, :] # Add a dimension for NC01
2219
+ C10 = self.backend.bk_expand_dims(c10,off_C01) # Add a dimension for NC01
2140
2220
  if calc_var:
2141
- VC10 = vc10[:, :, None, :, :] # Add a dimension for NC01
2221
+ VC10 = self.backend.bk_expand_dims(vc10,off_C01) # Add a dimension for NC01
2142
2222
  else:
2143
- C10 = self.backend.bk_concat([C10,c10[:, :, None, :, :]], axis=2) # Add a dimension for NC01
2223
+ C10 = self.backend.bk_concat([C10, self.backend.bk_expand_dims(c10,off_C01)], axis=2) # Add a dimension for NC01
2144
2224
  if calc_var:
2145
- VC10 = self.backend.bk_concat([VC10,vc10[:, :, None, :, :]], axis=2) # Add a dimension for NC01
2225
+ VC10 = self.backend.bk_concat([VC10, self.backend.bk_expand_dims(vc10,off_C01)], axis=2) # Add a dimension for NC01
2146
2226
 
2147
2227
 
2148
2228
  ##### C11
@@ -2167,18 +2247,18 @@ class funct(FOC.FoCUS):
2167
2247
  else:
2168
2248
  ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
2169
2249
  if norm is not None:
2170
- self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
2171
- P1_dic[j2][:, :, None, :,None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2250
+ self.div_norm(c11,(self.backend.bk_expand_dims(self.backend.bk_expand_dims(P1_dic[j1],off_P0),off_P0) *
2251
+ self.backend.bk_expand_dims(self.backend.bk_expand_dims(P1_dic[j2],off_P0),-1)) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2172
2252
  ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
2173
2253
  if C11 is None:
2174
- C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
2254
+ C11 = self.backend.bk_expand_dims(c11,off_C11) # Add a dimension for NC11
2175
2255
  if calc_var:
2176
- VC11 = vc11[:, :, None, :, :, :] # Add a dimension for NC11
2256
+ VC11 = self.backend.bk_expand_dims(vc11,off_C11) # Add a dimension for NC11
2177
2257
  else:
2178
- C11 = self.backend.bk_concat([C11,c11[:, :, None, :, :, :]],
2258
+ C11 = self.backend.bk_concat([C11,self.backend.bk_expand_dims(c11,off_C11)],
2179
2259
  axis=2) # Add a dimension for NC11
2180
2260
  if calc_var:
2181
- VC11 = self.backend.bk_concat([VC11,vc11[:, :, None, :, :, :]],
2261
+ VC11 = self.backend.bk_concat([VC11,self.backend.bk_expand_dims(vc11,off_C11)],
2182
2262
  axis=2) # Add a dimension for NC11
2183
2263
 
2184
2264
  ### C11_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
@@ -2201,18 +2281,18 @@ class funct(FOC.FoCUS):
2201
2281
  else:
2202
2282
  ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
2203
2283
  if norm is not None:
2204
- self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
2205
- P2_dic[j2][:, :, None, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2284
+ self.div_norm(c11,(self.backend.bk_expand_dims(self.backend.bk_expand_dims(P1_dic[j1],off_P0),off_P0) *
2285
+ self.backend.bk_expand_dims(self.backend.bk_expand_dims(P2_dic[j2],off_P0),-1)) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2206
2286
  ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
2207
2287
  if C11 is None:
2208
- C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
2288
+ C11 = self.backend.bk_expand_dims(c11,off_C11) # Add a dimension for NC11
2209
2289
  if calc_var:
2210
- VC11 = vc11[:, :, None, :, :, :] # Add a dimension for NC11
2290
+ VC11 = self.backend.bk_expand_dims(vc11,off_C11) # Add a dimension for NC11
2211
2291
  else:
2212
- C11 = self.backend.bk_concat([C11,c11[:, :, None, :, :, :]],
2292
+ C11 = self.backend.bk_concat([C11,self.backend.bk_expand_dims(c11,off_C11)],
2213
2293
  axis=2) # Add a dimension for NC11
2214
2294
  if calc_var:
2215
- VC11 = self.backend.bk_concat([VC11,vc11[:, :, None, :, :, :]],
2295
+ VC11 = self.backend.bk_concat([VC11,self.backend.bk_expand_dims(vc11,off_C11)],
2216
2296
  axis=2) # Add a dimension for NC11
2217
2297
 
2218
2298
  ###### Reshape for next iteration on j3
@@ -2298,7 +2378,10 @@ class funct(FOC.FoCUS):
2298
2378
  ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
2299
2379
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
2300
2380
  # cconv, sconv are [Nbatch, Npix_j3, Norient3]
2301
- c01 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(MconvPsi) # [Nbatch, Npix_j3, Norient3, Norient2]
2381
+ if self.use_1D:
2382
+ c01 = conv * self.backend.bk_conjugate(MconvPsi)
2383
+ else:
2384
+ c01 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(MconvPsi) # [Nbatch, Npix_j3, Norient3, Norient2]
2302
2385
 
2303
2386
  ### Apply the mask [Nmask, Npix_j3] and sum over pixels
2304
2387
  if return_data:
@@ -2327,7 +2410,10 @@ class funct(FOC.FoCUS):
2327
2410
 
2328
2411
  ### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
2329
2412
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
2330
- c11 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(self.backend.bk_expand_dims(M2, -1)) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
2413
+ if self.use_1D:
2414
+ c11 = M1 * self.backend.bk_conjugate(M2)
2415
+ else:
2416
+ c11 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(self.backend.bk_expand_dims(M2, -1)) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
2331
2417
 
2332
2418
  ### Apply the mask and sum over pixels
2333
2419
  if return_data: