foscat 3.0.47.tar.gz → 3.1.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31):
  1. {foscat-3.0.47 → foscat-3.1.1}/PKG-INFO +1 -1
  2. {foscat-3.0.47 → foscat-3.1.1}/setup.py +1 -1
  3. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/FoCUS.py +292 -47
  4. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/backend.py +24 -1
  5. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/scat_cov.py +136 -50
  6. foscat-3.0.47/src/foscat/scat_cov1D.py → foscat-3.1.1/src/foscat/scat_cov1D.old.py +4 -2
  7. foscat-3.1.1/src/foscat/scat_cov1D.py +16 -0
  8. {foscat-3.0.47 → foscat-3.1.1}/src/foscat.egg-info/PKG-INFO +1 -1
  9. {foscat-3.0.47 → foscat-3.1.1}/src/foscat.egg-info/SOURCES.txt +1 -0
  10. {foscat-3.0.47 → foscat-3.1.1}/README.md +0 -0
  11. {foscat-3.0.47 → foscat-3.1.1}/setup.cfg +0 -0
  12. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/CNN.py +0 -0
  13. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/CircSpline.py +0 -0
  14. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/GCNN.py +0 -0
  15. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/GetGPUinfo.py +0 -0
  16. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/Softmax.py +0 -0
  17. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/Spline1D.py +0 -0
  18. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/Synthesis.py +0 -0
  19. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/__init__.py +0 -0
  20. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/backend_tens.py +0 -0
  21. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/loss_backend_tens.py +0 -0
  22. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/loss_backend_torch.py +0 -0
  23. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/scat.py +0 -0
  24. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/scat1D.py +0 -0
  25. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/scat2D.py +0 -0
  26. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/scat_cov2D.py +0 -0
  27. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/scat_cov_map.py +0 -0
  28. {foscat-3.0.47 → foscat-3.1.1}/src/foscat/scat_cov_map2D.py +0 -0
  29. {foscat-3.0.47 → foscat-3.1.1}/src/foscat.egg-info/dependency_links.txt +0 -0
  30. {foscat-3.0.47 → foscat-3.1.1}/src/foscat.egg-info/requires.txt +0 -0
  31. {foscat-3.0.47 → foscat-3.1.1}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: foscat
3
- Version: 3.0.47
3
+ Version: 3.1.1
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Home-page: https://github.com/jmdelouis/FOSCAT
6
6
  Author: Jean-Marc DELOUIS
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
3
3
 
4
4
  setup(
5
5
  name='foscat',
6
- version='3.0.47',
6
+ version='3.1.1',
7
7
  description='Generate synthetic Healpix or 2D data using Cross Scattering Transform' ,
8
8
  long_description='Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising. \n A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO. \n Complete doc can be found at https://foscat-documentation.readthedocs.io/en/latest/index.html. \n\n List of developers : J.-M. Delouis, T. Foulquier, L. Mousset, T. Odaka, F. Paul, E. Allys ' ,
9
9
  license='MIT',
@@ -24,6 +24,7 @@ class FoCUS:
24
24
  TEMPLATE_PATH='data',
25
25
  BACKEND='tensorflow',
26
26
  use_2D=False,
27
+ use_1D=False,
27
28
  return_data=False,
28
29
  JmaxDelta=0,
29
30
  DODIV=False,
@@ -32,7 +33,7 @@ class FoCUS:
32
33
  mpi_size=1,
33
34
  mpi_rank=0):
34
35
 
35
- self.__version__ = '3.0.47'
36
+ self.__version__ = '3.1.1'
36
37
  # P00 coeff for normalization for scat_cov
37
38
  self.TMPFILE_VERSION=TMPFILE_VERSION
38
39
  self.P1_dic = None
@@ -86,6 +87,7 @@ class FoCUS:
86
87
 
87
88
  self.OSTEP=JmaxDelta
88
89
  self.use_2D=use_2D
90
+ self.use_1D=use_1D
89
91
 
90
92
  if isMPI:
91
93
  from mpi4py import MPI
@@ -214,12 +216,14 @@ class FoCUS:
214
216
 
215
217
 
216
218
  w_smooth=w_smooth.flatten()
217
-
219
+ if self.use_1D:
220
+ KERNELSZ=5
221
+
218
222
  self.KERNELSZ=KERNELSZ
219
223
 
220
224
  self.Idx_Neighbours={}
221
225
 
222
- if not self.use_2D:
226
+ if not self.use_2D and not self.use_1D:
223
227
  self.w_smooth = {}
224
228
  for i in range(nstep_max):
225
229
  lout=(2**i)
@@ -239,6 +243,27 @@ class FoCUS:
239
243
  self.ww_Real[lout]=wr
240
244
  self.ww_Imag[lout]=wi
241
245
  self.w_smooth[lout]=ws
246
+ elif self.use_1D==True:
247
+ self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
248
+ self.ww_RealT={}
249
+ self.ww_ImagT={}
250
+ self.ww_SmoothT={}
251
+ if KERNELSZ==5:
252
+ xx=np.arange(5)-2
253
+ w=np.exp(-0.25*(xx)**2)
254
+ c=w*np.cos((xx)*np.pi/2)
255
+ s=w*np.sin((xx)*np.pi/2)
256
+
257
+ w=w/np.sum(w)
258
+ c=c-np.mean(c)
259
+ s=s-np.mean(s)
260
+ r=np.sum(np.sqrt(c*c+s*s))
261
+ c=c/r
262
+ s=s/r
263
+ self.ww_RealT[1]=self.backend.constant(np.array(c).reshape(xx.shape[0],1,1))
264
+ self.ww_ImagT[1]=self.backend.constant(np.array(s).reshape(xx.shape[0],1,1))
265
+ self.ww_SmoothT[1] = self.backend.constant(np.array(w).reshape(xx.shape[0],1,1))
266
+
242
267
  else:
243
268
  self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
244
269
  self.ww_RealT={}
@@ -573,7 +598,7 @@ class FoCUS:
573
598
  tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
574
599
  tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),0:2*(npiy//2),:],[ndata,npix//2,2,npiy//2,2,odata])
575
600
 
576
- res=self.backend.bk_reduce_mean(self.backend.bk_reduce_mean(tim,4),2)
601
+ res=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim,4),2)/4
577
602
 
578
603
  if axis==0:
579
604
  if len(ishape)==2:
@@ -587,6 +612,40 @@ class FoCUS:
587
612
  return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2]+ishape[axis+2:])
588
613
 
589
614
  return self.backend.bk_reshape(res,[npix//2,npiy//2])
615
+ elif self.use_1D:
616
+ ishape=list(im.shape)
617
+ if len(ishape)<axis+1:
618
+ if not self.silent:
619
+ print('Use of 1D scat with data that has less than 1D')
620
+ return None
621
+
622
+ npix=im.shape[axis]
623
+ odata=1
624
+ if len(ishape)>axis+1:
625
+ for k in range(axis+1,len(ishape)):
626
+ odata=odata*ishape[k]
627
+
628
+ ndata=1
629
+ for k in range(axis):
630
+ ndata=ndata*ishape[k]
631
+
632
+ tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
633
+ tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),:],[ndata,npix//2,2,odata])
634
+
635
+ res=self.backend.bk_reduce_mean(tim,2)
636
+
637
+ if axis==0:
638
+ if len(ishape)==1:
639
+ return self.backend.bk_reshape(res,[npix//2])
640
+ else:
641
+ return self.backend.bk_reshape(res,[npix//2]+ishape[axis+1:])
642
+ else:
643
+ if len(ishape)==axis+1:
644
+ return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2])
645
+ else:
646
+ return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2]+ishape[axis+1:])
647
+
648
+ return self.backend.bk_reshape(res,[npix//2])
590
649
 
591
650
  else:
592
651
  shape=list(im.shape)
@@ -653,6 +712,46 @@ class FoCUS:
653
712
  return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty]+ishape[axis+2:])
654
713
 
655
714
  return self.backend.bk_reshape(res,[nout,nouty])
715
+
716
+ elif self.use_1D:
717
+ ishape=list(im.shape)
718
+ if len(ishape)<axis+1:
719
+ if not self.silent:
720
+ print('Use of 1D scat with data that has less than 1D')
721
+ return None
722
+
723
+ if ishape[axis]==nout:
724
+ return im
725
+
726
+ npix=im.shape[axis]
727
+ odata=1
728
+ if len(ishape)>axis+1:
729
+ for k in range(axis+1,len(ishape)):
730
+ odata=odata*ishape[k]
731
+
732
+ ndata=1
733
+ for k in range(axis):
734
+ ndata=ndata*ishape[k]
735
+
736
+ tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
737
+
738
+ while tim.shape[1]!=nout:
739
+ res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
740
+ res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
741
+ tim = self.backend.bk_reshape(self.backend.bk_concat([res1,res2],-2),[ndata,tim.shape[1]*2,odata])
742
+
743
+ if axis==0:
744
+ if len(ishape)==1:
745
+ return self.backend.bk_reshape(tim,[nout])
746
+ else:
747
+ return self.backend.bk_reshape(tim,[nout]+ishape[axis+1:])
748
+ else:
749
+ if len(ishape)==axis+1:
750
+ return self.backend.bk_reshape(tim,ishape[0:axis]+[nout])
751
+ else:
752
+ return self.backend.bk_reshape(tim,ishape[0:axis]+[nout]+ishape[axis+1:])
753
+
754
+ return self.backend.bk_reshape(tim,[nout])
656
755
 
657
756
  else:
658
757
 
@@ -842,7 +941,7 @@ class FoCUS:
842
941
  ndata=ndata*ishape[k]
843
942
 
844
943
  tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
845
-
944
+
846
945
  res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
847
946
  res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
848
947
  res = self.backend.bk_concat([res1,res2],-2)
@@ -1278,66 +1377,117 @@ class FoCUS:
1278
1377
  sum_mask=self.backend.bk_reduce_sum(self.backend.bk_reshape(l_mask,[l_mask.shape[0],np.prod(np.array(l_mask.shape[1:]))]),1)
1279
1378
  if not self.use_2D:
1280
1379
  l_mask=12*nside*nside*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1281
- else:
1380
+ elif self.use_2D:
1282
1381
  l_mask=mask.shape[1]*mask.shape[2]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1283
-
1382
+ else:
1383
+ l_mask=mask.shape[1]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1384
+
1284
1385
  if self.use_2D:
1285
- if self.padding=='VALID' and shape[axis]!=l_mask.shape[1]:
1386
+ if self.padding=='VALID':
1286
1387
  l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1287
1388
  if shape[axis]!=l_mask.shape[1]:
1288
1389
  l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1289
- else:
1290
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1291
1390
 
1292
- # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,...,KERNELSZ//2:-self.KERNELSZ//2,KERNELSZ//2:-self.KERNELSZ//2,NORIENT[,NORIENT]]
1293
- if self.use_2D:
1294
1391
  ichannel=1
1295
1392
  for i in range(axis):
1296
1393
  ichannel*=shape[i]
1297
1394
  ochannel=1
1298
1395
  for i in range(axis+2,len(shape)):
1299
1396
  ochannel*=shape[i]
1300
- l_x=self.backend.bk_reshape(x,[ichannel,shape[axis],shape[axis+1],ochannel])
1301
- oshape=[k for k in shape]
1302
- oshape[axis]=oshape[axis]-self.KERNELSZ+1
1303
- oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
1304
- l_x=self.backend.bk_reshape(l_x[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1305
- else:
1306
- l_x=x
1397
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],shape[axis+1],ochannel])
1398
+
1399
+ if self.padding=='VALID':
1400
+ oshape=[k for k in shape]
1401
+ oshape[axis]=oshape[axis]-self.KERNELSZ+1
1402
+ oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
1403
+ l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1404
+
1405
+ elif self.use_1D:
1406
+ if self.padding=='VALID':
1407
+ l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1408
+ if shape[axis]!=l_mask.shape[1]:
1409
+ l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1410
+
1411
+ ichannel=1
1412
+ for i in range(axis):
1413
+ ichannel*=shape[i]
1414
+ ochannel=1
1415
+ for i in range(axis+1,len(shape)):
1416
+ ochannel*=shape[i]
1417
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1307
1418
 
1419
+ if self.padding=='VALID':
1420
+ oshape=[k for k in shape]
1421
+ oshape[axis]=oshape[axis]-self.KERNELSZ+1
1422
+ l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1423
+ else:
1424
+ ichannel=1
1425
+ for i in range(axis):
1426
+ ichannel*=shape[i]
1427
+ ochannel=1
1428
+ for i in range(axis+1,len(shape)):
1429
+ ochannel*=shape[i]
1430
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1431
+
1308
1432
  # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,1,...,X[,Y],NORIENT[,NORIENT]]
1309
- l_x=self.backend.bk_expand_dims(l_x,1)
1310
-
1311
1433
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,X[,Y]]
1312
1434
  l_mask=self.backend.bk_expand_dims(l_mask,0)
1313
-
1314
- # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
1315
- for i in range(1,axis):
1316
- l_mask=self.backend.bk_expand_dims(l_mask,axis)
1435
+ # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,X[,Y],1]
1436
+ l_mask=self.backend.bk_expand_dims(l_mask,-1)
1317
1437
 
1318
1438
  if l_x.dtype==self.all_cbk_type:
1319
1439
  l_mask=self.backend.bk_complex(l_mask,self.backend.bk_cast(0.0*l_mask))
1320
1440
 
1321
1441
  if self.use_2D:
1442
+ mtmp=l_mask
1443
+ vtmp=l_x
1444
+
1445
+ v1=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp,axis=2),2)
1446
+ v2=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2),2)
1447
+ vh=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp,axis=2),2)
1322
1448
 
1323
- # mask=[1,Nmask,....,X,Y] => mask=[1,Nmask,....,X,Y,....]
1324
- for i in range(axis+2,len(x.shape)):
1325
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1326
-
1327
- shape1=list(l_mask.shape)
1328
- shape2=list(l_x.shape)
1449
+ res=v1/vh
1450
+
1451
+ oshape=[]
1452
+ if axis>0:
1453
+ oshape=oshape+list(x.shape[0:axis])
1454
+ oshape=oshape+[mask.shape[0]]
1455
+ if axis+1<len(x.shape):
1456
+ oshape=oshape+list(x.shape[axis+2:])
1457
+
1458
+ if calc_var:
1459
+ if self.backend.bk_is_complex(vtmp):
1460
+ res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1461
+ -self.backend.bk_real(res)*self.backend.bk_real(res)) + \
1462
+ (self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
1463
+ -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1464
+ else:
1465
+ res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1329
1466
 
1330
- oshape1=shape1[0:axis+1]+[shape1[axis+1]*shape1[axis+2]]+shape1[axis+3:]
1331
- oshape2=shape2[0:axis+1]+[shape2[axis+1]*shape2[axis+2]]+shape2[axis+3:]
1467
+ res=self.backend.bk_reshape(res,oshape)
1468
+ res2=self.backend.bk_reshape(res2,oshape)
1469
+ return res,res2
1470
+ else:
1471
+ res=self.backend.bk_reshape(res,oshape)
1472
+ return res
1332
1473
 
1333
- mtmp=self.backend.bk_reshape(l_mask,oshape1)
1334
- vtmp=self.backend.bk_reshape(l_x,oshape2)
1474
+ elif self.use_1D:
1475
+ mtmp=l_mask
1476
+ vtmp=l_x
1335
1477
 
1336
- v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=axis+1)
1337
- v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=axis+1)
1338
- vh=self.backend.bk_reduce_sum(mtmp,axis=axis+1)
1478
+ v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=2)
1479
+ v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2)
1480
+ vh=self.backend.bk_reduce_sum(mtmp,axis=2)
1339
1481
 
1340
1482
  res=v1/vh
1483
+
1484
+ oshape=[]
1485
+ if axis>0:
1486
+ oshape=oshape+list(x.shape[0:axis])
1487
+ oshape=oshape+[mask.shape[0]]
1488
+ if axis+1<len(x.shape):
1489
+ oshape=oshape+list(x.shape[axis+1:])
1490
+
1341
1491
  if calc_var:
1342
1492
  if self.backend.bk_is_complex(vtmp):
1343
1493
  res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
@@ -1346,19 +1496,29 @@ class FoCUS:
1346
1496
  -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1347
1497
  else:
1348
1498
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1499
+
1500
+
1501
+ res=self.backend.bk_reshape(res,oshape)
1502
+ res2=self.backend.bk_reshape(res2,oshape)
1349
1503
  return res,res2
1350
1504
  else:
1505
+ res=self.backend.bk_reshape(res,oshape)
1351
1506
  return res
1352
- else:
1353
- # mask=[1,Nmask,....,X] => mask=[1,Nmask,....,X,....]
1354
- for i in range(axis+1,len(x.shape)):
1355
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1356
-
1357
- v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=axis+1)
1358
- v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=axis+1)
1359
- vh=self.backend.bk_reduce_sum(l_mask,axis=axis+1)
1507
+
1508
+ else:
1509
+ v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=2)
1510
+ v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=2)
1511
+ vh=self.backend.bk_reduce_sum(l_mask,axis=2)
1360
1512
 
1361
1513
  res=v1/vh
1514
+
1515
+ oshape=[]
1516
+ if axis>0:
1517
+ oshape=oshape+list(x.shape[0:axis])
1518
+ oshape=oshape+[mask.shape[0]]
1519
+ if axis+1<len(x.shape):
1520
+ oshape=oshape+list(x.shape[axis+1:])
1521
+
1362
1522
  if calc_var:
1363
1523
  if self.backend.bk_is_complex(l_x):
1364
1524
  res2=self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
@@ -1367,8 +1527,12 @@ class FoCUS:
1367
1527
  -self.backend.bk_imag(res)*self.backend.bk_imag(res))/self.backend.bk_real(vh))
1368
1528
  else:
1369
1529
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1530
+
1531
+ res=self.backend.bk_reshape(res,oshape)
1532
+ res2=self.backend.bk_reshape(res2,oshape)
1370
1533
  return res,res2
1371
1534
  else:
1535
+ res=self.backend.bk_reshape(res,oshape)
1372
1536
  return res
1373
1537
 
1374
1538
  # ---------------------------------------------−---------
@@ -1485,7 +1649,6 @@ class FoCUS:
1485
1649
  image=self.backend.bk_cast(in_image)
1486
1650
 
1487
1651
  if self.use_2D:
1488
-
1489
1652
  ishape=list(in_image.shape)
1490
1653
  if len(ishape)<axis+2:
1491
1654
  if not self.silent:
@@ -1528,6 +1691,49 @@ class FoCUS:
1528
1691
  return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
1529
1692
 
1530
1693
  return self.backend.bk_reshape(res,[nout,nouty])
1694
+ elif self.use_1D==True:
1695
+ ishape=list(in_image.shape)
1696
+ if len(ishape)<axis+1:
1697
+ if not self.silent:
1698
+ print('Use of 1D scat with data that has less than 1D')
1699
+ return None
1700
+
1701
+ npix=ishape[axis]
1702
+ odata=1
1703
+ if len(ishape)>axis+1:
1704
+ for k in range(axis+1,len(ishape)):
1705
+ odata=odata*ishape[k]
1706
+
1707
+ ndata=1
1708
+ for k in range(axis):
1709
+ ndata=ndata*ishape[k]
1710
+
1711
+ tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
1712
+
1713
+ if self.backend.bk_is_complex(tim):
1714
+ rr1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1715
+ ii1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1716
+ rr2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1717
+ ii2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1718
+ res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
1719
+ else:
1720
+ rr=self.backend.conv1d(tim,self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1721
+ ii=self.backend.conv1d(tim,self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1722
+ res=self.backend.bk_complex(rr,ii)
1723
+
1724
+ if axis==0:
1725
+ if len(ishape)==1:
1726
+ return self.backend.bk_reshape(res,[res.shape[1]])
1727
+ else:
1728
+ return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+2:])
1729
+ else:
1730
+ if len(ishape)==axis+1:
1731
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
1732
+ else:
1733
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1734
+
1735
+ return self.backend.bk_reshape(res,[nout,nouty])
1736
+
1531
1737
 
1532
1738
  else:
1533
1739
  nside=int(np.sqrt(image.shape[axis]//12))
@@ -1649,7 +1855,46 @@ class FoCUS:
1649
1855
  return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]]+ishape[axis+2:])
1650
1856
 
1651
1857
  return self.backend.bk_reshape(res,[nout,nouty])
1858
+ elif self.use_1D:
1859
+
1860
+ ishape=list(in_image.shape)
1861
+ if len(ishape)<axis+1:
1862
+ if not self.silent:
1863
+ print('Use of 1D scat with data that has less than 1D')
1864
+ return None
1652
1865
 
1866
+ npix=ishape[axis]
1867
+ odata=1
1868
+ if len(ishape)>axis+1:
1869
+ for k in range(axis+1,len(ishape)):
1870
+ odata=odata*ishape[k]
1871
+
1872
+ ndata=1
1873
+ for k in range(axis):
1874
+ ndata=ndata*ishape[k]
1875
+
1876
+ tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
1877
+
1878
+ if self.backend.bk_is_complex(tim):
1879
+ rr=self.backend.conv1d(self.backend.bk_real(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1880
+ ii=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1881
+ res=self.backend.bk_complex(rr,ii)
1882
+ else:
1883
+ res=self.backend.conv1d(tim,self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1884
+
1885
+ if axis==0:
1886
+ if len(ishape)==1:
1887
+ return self.backend.bk_reshape(res,[res.shape[1]])
1888
+ else:
1889
+ return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+1:])
1890
+ else:
1891
+ if len(ishape)==axis+1:
1892
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
1893
+ else:
1894
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1895
+
1896
+ return self.backend.bk_reshape(res,[nout,nouty])
1897
+
1653
1898
  else:
1654
1899
  nside=int(np.sqrt(image.shape[axis]//12))
1655
1900
 
@@ -289,10 +289,33 @@ class foscat_backend:
289
289
  for k in range(w.shape[2]):
290
290
  for l in range(w.shape[3]):
291
291
  for j in range(res.shape[0]):
292
- tmp=self.scipy.signal.convolve2d(x[j,:,:,k],w[:,:,k,l], mode='same', boundary='fill', fillvalue=0.0)
292
+ tmp=self.scipy.signal.convolve2d(x[j,:,:,k],w[:,:,k,l], mode='same', boundary='symm')
293
293
  res[j,:,:,l]+=tmp
294
294
  del tmp
295
295
  return res
296
+
297
+ def conv1d(self,x,w,strides=[1, 1, 1],padding='SAME'):
298
+ if self.BACKEND==self.TENSORFLOW:
299
+ kx=w.shape[0]
300
+ paddings = self.backend.constant([[0,0],
301
+ [kx//2,kx//2],
302
+ [0,0]])
303
+ tmp=self.backend.pad(x, paddings, "SYMMETRIC")
304
+
305
+ return self.backend.nn.conv1d(tmp,w,
306
+ stride=strides,
307
+ padding="VALID")
308
+ # to be written!!!
309
+ if self.BACKEND==self.TORCH:
310
+ return x
311
+ if self.BACKEND==self.NUMPY:
312
+ res=np.zeros([x.shape[0],x.shape[1],w.shape[2]],dtype=x.dtype)
313
+ for k in range(w.shape[2]):
314
+ for j in range(res.shape[0]):
315
+ tmp=self.scipy.signal.convolve1d(x[j,:,k],w[:,k,l], mode='same', boundary='symm')
316
+ res[j,:,:,l]+=tmp
317
+ del tmp
318
+ return res
296
319
 
297
320
  def bk_threshold(self,x,threshold,greater=True):
298
321
 
@@ -859,6 +859,30 @@ class scat_cov:
859
859
  tab2nx=tab2nx+['%d'%(i2)]
860
860
  ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
861
861
  n=n+j2[j1==i2].shape[0]-1
862
+ elif len(tmp.shape)==3:
863
+ for i0 in range(tmp.shape[0]):
864
+ for i1 in range(tmp.shape[1]):
865
+ for i2 in range(j1.max()+1):
866
+ dtmp=tmp[i0,i1,j1==i2]
867
+ if norm:
868
+ dtmp=dtmp/(ntmp[i0,i1,i2]*ntmp[i0,i1,j2[j1==i2]])
869
+ if j2[j1==i2].shape[0]==1:
870
+ ax1.plot(j2[j1==i2]+n,dtmp,'.', \
871
+ color=color, lw=lw)
872
+ else:
873
+ if legend and test is None:
874
+ ax1.plot(j2[j1==i2]+n,dtmp, \
875
+ color=color, label=lname, lw=lw)
876
+ test=1
877
+ ax1.plot(j2[j1==i2]+n,dtmp, \
878
+ color=color, lw=lw)
879
+ tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
880
+ tabx=tabx+[k+n for k in j2[j1==i2]]
881
+ tab2x=tab2x+[(j2[j1==i2]+n).mean()]
882
+ tab2nx=tab2nx+['%d'%(i2)]
883
+ ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
884
+ n=n+j2[j1==i2].shape[0]-1
885
+
862
886
  else:
863
887
  for i0 in range(tmp.shape[0]):
864
888
  for i1 in range(tmp.shape[1]):
@@ -951,6 +975,32 @@ class scat_cov:
951
975
  tab2x=tab2x+[(n+nprev-1)/2]
952
976
  tab2nx=tab2nx+['%d'%(i2)]
953
977
  ax1.axvline(n-0.5,ls=':',color='gray')
978
+ elif len(tmp.shape)==3:
979
+ for i0 in range(tmp.shape[0]):
980
+ for i1 in range(tmp.shape[1]):
981
+ for i2 in range(j1.max()+1):
982
+ nprev=n
983
+ for i2b in range(j2[j1==i2].max()+1):
984
+ idx=np.where((j1==i2)*(j2==i2b))[0]
985
+ dtmp=tmp[i0,i1,idx]
986
+ if norm:
987
+ dtmp=dtmp/(ntmp[i0,i1,i2]*ntmp[i0,i1,i2b])
988
+ if len(idx)==1:
989
+ ax1.plot(np.arange(len(idx))+n,dtmp,'.', \
990
+ color=color, lw=lw)
991
+ else:
992
+ if legend and test is None:
993
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
994
+ color=color, label=lname, lw=lw)
995
+ test=1
996
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
997
+ color=color, lw=lw)
998
+ tabnx=tabnx+[r'%d,%d'%(j2[k],j3[k]) for k in idx]
999
+ tabx=tabx+[k+n for k in range(len(idx))]
1000
+ n=n+idx.shape[0]
1001
+ tab2x=tab2x+[(n+nprev-1)/2]
1002
+ tab2nx=tab2nx+['%d'%(i2)]
1003
+ ax1.axvline(n-0.5,ls=':',color='gray')
954
1004
  else:
955
1005
  for i0 in range(tmp.shape[0]):
956
1006
  for i1 in range(tmp.shape[1]):
@@ -1495,6 +1545,8 @@ class funct(FOC.FoCUS):
1495
1545
  def fill(self,im,nullval=hp.UNSEEN):
1496
1546
  if self.use_2D:
1497
1547
  return self.fill_2d(im,nullval=nullval)
1548
+ if self.use_1D:
1549
+ return self.fill_1d(im,nullval=nullval)
1498
1550
  return self.fill_healpy(im,nullval=nullval)
1499
1551
 
1500
1552
  def moments(self,list_scat):
@@ -1748,6 +1800,15 @@ class funct(FOC.FoCUS):
1748
1800
  x1=im_shape[1]
1749
1801
  x2=im_shape[2]
1750
1802
  J = int(np.log(nside-self.KERNELSZ) / np.log(2)) # Number of j scales
1803
+ elif self.use_1D:
1804
+ if len(image1.shape)==2:
1805
+ npix = int(im_shape[1]) # Number of pixels
1806
+ else:
1807
+ npix = int(im_shape[0]) # Number of pixels
1808
+
1809
+ nside=int(npix)
1810
+
1811
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
1751
1812
  else:
1752
1813
  if len(image1.shape)==2:
1753
1814
  npix = int(im_shape[1]) # Number of pixels
@@ -1785,6 +1846,11 @@ class funct(FOC.FoCUS):
1785
1846
  I1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1786
1847
  if cross:
1787
1848
  I2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1849
+ elif self.use_1D:
1850
+ vmask=self.up_grade(vmask,I1.shape[axis]*2,axis=1)
1851
+ I1=self.up_grade(I1,I1.shape[axis]*2,axis=axis)
1852
+ if cross:
1853
+ I2=self.up_grade(I2,I2.shape[axis]*2,axis=axis)
1788
1854
  else:
1789
1855
  I1 = self.up_grade(I1, nside * 2, axis=axis)
1790
1856
  vmask = self.up_grade(vmask, nside * 2, axis=1)
@@ -1798,6 +1864,11 @@ class funct(FOC.FoCUS):
1798
1864
  I1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1799
1865
  if cross:
1800
1866
  I2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1867
+ elif self.use_1D:
1868
+ vmask=self.up_grade(vmask,I1.shape[axis]*4,axis=1)
1869
+ I1=self.up_grade(I1,I1.shape[axis]*4,axis=axis)
1870
+ if cross:
1871
+ I2=self.up_grade(I2,I2.shape[axis]*4,axis=axis)
1801
1872
  else:
1802
1873
  I1 = self.up_grade(I1, nside * 4, axis=axis)
1803
1874
  vmask = self.up_grade(vmask, nside * 4, axis=1)
@@ -1811,6 +1882,14 @@ class funct(FOC.FoCUS):
1811
1882
  # Coefficients
1812
1883
  S1, P00, C01, C11, C10 = None, None, None, None, None
1813
1884
 
1885
+ off_P0=-2
1886
+ off_C01=-3
1887
+ off_C11=-4
1888
+ if self.use_1D:
1889
+ off_P0=-1
1890
+ off_C01=-1
1891
+ off_C11=-1
1892
+
1814
1893
  # Dictionaries for C01 computation
1815
1894
  M1_dic = {} # M stands for Module M1 = |I1 * Psi|
1816
1895
  if cross:
@@ -1842,7 +1921,6 @@ class funct(FOC.FoCUS):
1842
1921
  s0 = self.masked_mean(I1,vmask,axis=1)
1843
1922
  else:
1844
1923
  s0 = self.masked_mean(I1-I2,vmask,axis=1)
1845
-
1846
1924
 
1847
1925
  #### COMPUTE S1, P00, C01 and C11
1848
1926
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
@@ -1887,6 +1965,7 @@ class funct(FOC.FoCUS):
1887
1965
  else:
1888
1966
  p00 = self.masked_mean(M1_square, vmask, axis=1,rank=j3)
1889
1967
 
1968
+
1890
1969
  if cond_init_P1_dic:
1891
1970
  # We fill P1_dic with P00 for normalisation of C01 and C11
1892
1971
  P1_dic[j3] = p00 # [Nbatch, Nmask, Norient3]
@@ -1900,13 +1979,13 @@ class funct(FOC.FoCUS):
1900
1979
  if norm == 'auto': # Normalize P00
1901
1980
  p00 /= P1_dic[j3]
1902
1981
  if P00 is None:
1903
- P00 = p00[:, :, None, :] # Add a dimension for NP00
1982
+ P00 = self.backend.bk_expand_dims(p00,off_P0) # Add a dimension for NP00
1904
1983
  if calc_var:
1905
- VP00 = vp00[:, :, None, :] # Add a dimension for NP00
1984
+ VP00 = self.backend.bk_expand_dims(vp00,off_P0) # Add a dimension for NP00
1906
1985
  else:
1907
- P00 = self.backend.bk_concat([P00, p00[:, :, None, :]], axis=2)
1986
+ P00 = self.backend.bk_concat([P00, self.backend.bk_expand_dims(p00,off_P0)], axis=2)
1908
1987
  if calc_var:
1909
- VP00 = self.backend.bk_concat([VP00, vp00[:, :, None, :]], axis=2)
1988
+ VP00 = self.backend.bk_concat([VP00, self.backend.bk_expand_dims(vp00,off_P0)], axis=2)
1910
1989
 
1911
1990
  #### S1_auto computation
1912
1991
  ### Image 1 : S1 = < M1 >_pix
@@ -1929,13 +2008,13 @@ class funct(FOC.FoCUS):
1929
2008
  self.div_norm(s1,(P1_dic[j3]) ** 0.5)
1930
2009
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
1931
2010
  if S1 is None:
1932
- S1 = s1[:, :, None, :] # Add a dimension for NS1
2011
+ S1 = self.backend.bk_expand_dims(s1,off_P0) # Add a dimension for NS1
1933
2012
  if calc_var:
1934
- VS1 = vs1[:, :, None, :] # Add a dimension for NS1
2013
+ VS1 = self.backend.bk_expand_dims(vs1,off_P0) # Add a dimension for NS1
1935
2014
  else:
1936
- S1 = self.backend.bk_concat([S1, s1[:, :, None, :]], axis=2)
2015
+ S1 = self.backend.bk_concat([S1,self.backend.bk_expand_dims(s1,off_P0)], axis=2)
1937
2016
  if calc_var:
1938
- VS1 = self.backend.bk_concat([VS1, vs1[:, :, None, :]], axis=2)
2017
+ VS1 = self.backend.bk_concat([VS1, self.backend.bk_expand_dims(vs1,off_P0)], axis=2)
1939
2018
 
1940
2019
  else: # Cross
1941
2020
  ### Make the convolution I2 * Psi_j3
@@ -1994,13 +2073,13 @@ class funct(FOC.FoCUS):
1994
2073
  p00=self.backend.bk_real(p00)
1995
2074
 
1996
2075
  if P00 is None:
1997
- P00 = p00[:,:,None,:] # Add a dimension for NP00
2076
+ P00 = self.backend.bk_expand_dims(p00,off_P0) # Add a dimension for NP00
1998
2077
  if calc_var:
1999
- VP00 = vp00[:,:,None,:] # Add a dimension for NP00
2078
+ VP00 = self.backend.bk_expand_dims(vp00,off_P0) # Add a dimension for NP00
2000
2079
  else:
2001
- P00 = self.backend.bk_concat([P00, p00[:,:,None,:]], axis=2)
2080
+ P00 = self.backend.bk_concat([P00, self.backend.bk_expand_dims(p00,off_P0)], axis=2)
2002
2081
  if calc_var:
2003
- VP00 = self.backend.bk_concat([VP00, vp00[:,:,None,:]], axis=2)
2082
+ VP00 = self.backend.bk_concat([VP00, self.backend.bk_expand_dims(vp00,off_P0)], axis=2)
2004
2083
 
2005
2084
  #### S1_auto computation
2006
2085
  ### Image 1 : S1 = < M1 >_pix
@@ -2022,14 +2101,15 @@ class funct(FOC.FoCUS):
2022
2101
  self.div_norm(s1,(P1_dic[j3]) ** 0.5)
2023
2102
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2024
2103
  if S1 is None:
2025
- S1 = s1[:, :, None, :] # Add a dimension for NS1
2104
+ S1 = self.backend.bk_expand_dims(s1,off_P0) # Add a dimension for NS1
2026
2105
  if calc_var:
2027
- VS1 = vs1[:, :, None, :] # Add a dimension for NS1
2106
+ VS1 = self.backend.bk_expand_dims(vs1,off_P0) # Add a dimension for NS1
2028
2107
  else:
2029
- S1 = self.backend.bk_concat([S1, s1[:, :, None, :]], axis=2)
2108
+ S1 = self.backend.bk_concat([S1, self.backend.bk_expand_dims(s1,off_P0)], axis=2)
2030
2109
  if calc_var:
2031
- VS1 = self.backend.bk_concat([VS1, vs1[:, :, None, :]], axis=2)
2032
-
2110
+ VS1 = self.backend.bk_concat([VS1,
2111
+ self.backend.bk_expand_dims(vs1,off_P0)], axis=2)
2112
+
2033
2113
  # Initialize dictionaries for |I1*Psi_j| * Psi_j3
2034
2114
  M1convPsi_dic = {}
2035
2115
  if cross:
@@ -2067,19 +2147,19 @@ class funct(FOC.FoCUS):
2067
2147
  else:
2068
2148
  ### Normalize C01 with P00_j [Nbatch, Nmask, Norient_j]
2069
2149
  if norm is not None:
2070
- self.div_norm(c01,(P1_dic[j2][:, :, None, :] *
2071
- P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2150
+ self.div_norm(c01,(self.backend.bk_expand_dims(P1_dic[j2],off_P0) *
2151
+ self.backend.bk_expand_dims(P1_dic[j3],-1)) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2072
2152
 
2073
2153
  ### Store C01 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
2074
2154
  if C01 is None:
2075
- C01 = c01[:,:,None,:,:] # Add a dimension for NC01
2155
+ C01 = self.backend.bk_expand_dims(c01,off_C01) # Add a dimension for NC01
2076
2156
  if calc_var:
2077
- VC01 = vc01[:,:,None,:,:] # Add a dimension for NC01
2157
+ VC01 =self.backend.bk_expand_dims(vc01,off_C01) # Add a dimension for NC01
2078
2158
  else:
2079
- C01 = self.backend.bk_concat([C01, c01[:, :, None, :, :]],
2159
+ C01 = self.backend.bk_concat([C01, self.backend.bk_expand_dims(c01,off_C01)],
2080
2160
  axis=2) # Add a dimension for NC01
2081
2161
  if calc_var:
2082
- VC01 = self.backend.bk_concat([VC01, vc01[:, :, None, :, :]],
2162
+ VC01 = self.backend.bk_concat([VC01, self.backend.bk_expand_dims(vc01,off_C01)],
2083
2163
  axis=2) # Add a dimension for NC01
2084
2164
 
2085
2165
  ### C01_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
@@ -2121,28 +2201,28 @@ class funct(FOC.FoCUS):
2121
2201
  else:
2122
2202
  ### Normalize C01 and C10 with P00_j [Nbatch, Nmask, Norient_j]
2123
2203
  if norm is not None:
2124
- self.div_norm(c01,(P2_dic[j2][:, :, None, :] *
2125
- P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2126
- self.div_norm(c10,(P1_dic[j2][:, :, None, :] *
2127
- P2_dic[j3][:, :, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2]
2204
+ self.div_norm(c01,(self.backend.bk_expand_dims(P2_dic[j2],off_P0) *
2205
+ self.backend.bk_expand_dims(P1_dic[j3],-1)) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2206
+ self.div_norm(c10,(self.backend.bk_expand_dims(P1_dic[j2],off_P0) *
2207
+ self.backend.bk_expand_dims(P2_dic[j3],-1)) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2]
2128
2208
 
2129
2209
  ### Store C01 and C10 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
2130
2210
  if C01 is None:
2131
- C01 = c01[:, :, None, :, :] # Add a dimension for NC01
2211
+ C01 = self.backend.bk_expand_dims(c01,off_C01) # Add a dimension for NC01
2132
2212
  if calc_var:
2133
- VC01 = vc01[:, :, None, :, :] # Add a dimension for NC01
2213
+ VC01 = vself.backend.bk_expand_dims(vc01,off_C01) # Add a dimension for NC01
2134
2214
  else:
2135
- C01 = self.backend.bk_concat([C01,c01[:, :, None, :, :]],axis=2) # Add a dimension for NC01
2215
+ C01 = self.backend.bk_concat([C01, self.backend.bk_expand_dims(c01,off_C01)],axis=2) # Add a dimension for NC01
2136
2216
  if calc_var:
2137
- VC01 = self.backend.bk_concat([VC01,vc01[:, :, None, :, :]],axis=2) # Add a dimension for NC01
2217
+ VC01 =self.backend.bk_concat([VC01, self.backend.bk_expand_dims(vc01,off_C01)],axis=2) # Add a dimension for NC01
2138
2218
  if C10 is None:
2139
- C10 = c10[:, :, None, :, :] # Add a dimension for NC01
2219
+ C10 = self.backend.bk_expand_dims(c10,off_C01) # Add a dimension for NC01
2140
2220
  if calc_var:
2141
- VC10 = vc10[:, :, None, :, :] # Add a dimension for NC01
2221
+ VC10 = self.backend.bk_expand_dims(vc10,off_C01) # Add a dimension for NC01
2142
2222
  else:
2143
- C10 = self.backend.bk_concat([C10,c10[:, :, None, :, :]], axis=2) # Add a dimension for NC01
2223
+ C10 = self.backend.bk_concat([C10, self.backend.bk_expand_dims(c10,off_C01)], axis=2) # Add a dimension for NC01
2144
2224
  if calc_var:
2145
- VC10 = self.backend.bk_concat([VC10,vc10[:, :, None, :, :]], axis=2) # Add a dimension for NC01
2225
+ VC10 = self.backend.bk_concat([VC10, self.backend.bk_expand_dims(vc10,off_C01)], axis=2) # Add a dimension for NC01
2146
2226
 
2147
2227
 
2148
2228
  ##### C11
@@ -2167,18 +2247,18 @@ class funct(FOC.FoCUS):
2167
2247
  else:
2168
2248
  ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
2169
2249
  if norm is not None:
2170
- self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
2171
- P1_dic[j2][:, :, None, :,None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2250
+ self.div_norm(c11,(self.backend.bk_expand_dims(self.backend.bk_expand_dims(P1_dic[j1],off_P0),off_P0) *
2251
+ self.backend.bk_expand_dims(self.backend.bk_expand_dims(P1_dic[j2],off_P0),-1)) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2172
2252
  ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
2173
2253
  if C11 is None:
2174
- C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
2254
+ C11 = self.backend.bk_expand_dims(c11,off_C11) # Add a dimension for NC11
2175
2255
  if calc_var:
2176
- VC11 = vc11[:, :, None, :, :, :] # Add a dimension for NC11
2256
+ VC11 = self.backend.bk_expand_dims(vc11,off_C11) # Add a dimension for NC11
2177
2257
  else:
2178
- C11 = self.backend.bk_concat([C11,c11[:, :, None, :, :, :]],
2258
+ C11 = self.backend.bk_concat([C11,self.backend.bk_expand_dims(c11,off_C11)],
2179
2259
  axis=2) # Add a dimension for NC11
2180
2260
  if calc_var:
2181
- VC11 = self.backend.bk_concat([VC11,vc11[:, :, None, :, :, :]],
2261
+ VC11 = self.backend.bk_concat([VC11,self.backend.bk_expand_dims(vc11,off_C11)],
2182
2262
  axis=2) # Add a dimension for NC11
2183
2263
 
2184
2264
  ### C11_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
@@ -2201,18 +2281,18 @@ class funct(FOC.FoCUS):
2201
2281
  else:
2202
2282
  ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
2203
2283
  if norm is not None:
2204
- self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
2205
- P2_dic[j2][:, :, None, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2284
+ self.div_norm(c11,(self.backend.bk_expand_dims(self.backend.bk_expand_dims(P1_dic[j1],off_P0),off_P0) *
2285
+ self.backend.bk_expand_dims(self.backend.bk_expand_dims(P2_dic[j2],off_P0),-1)) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2206
2286
  ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
2207
2287
  if C11 is None:
2208
- C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
2288
+ C11 = self.backend.bk_expand_dims(c11,off_C11) # Add a dimension for NC11
2209
2289
  if calc_var:
2210
- VC11 = vc11[:, :, None, :, :, :] # Add a dimension for NC11
2290
+ VC11 = self.backend.bk_expand_dims(vc11,off_C11) # Add a dimension for NC11
2211
2291
  else:
2212
- C11 = self.backend.bk_concat([C11,c11[:, :, None, :, :, :]],
2292
+ C11 = self.backend.bk_concat([C11,self.backend.bk_expand_dims(c11,off_C11)],
2213
2293
  axis=2) # Add a dimension for NC11
2214
2294
  if calc_var:
2215
- VC11 = self.backend.bk_concat([VC11,vc11[:, :, None, :, :, :]],
2295
+ VC11 = self.backend.bk_concat([VC11,self.backend.bk_expand_dims(vc11,off_C11)],
2216
2296
  axis=2) # Add a dimension for NC11
2217
2297
 
2218
2298
  ###### Reshape for next iteration on j3
@@ -2298,7 +2378,10 @@ class funct(FOC.FoCUS):
2298
2378
  ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
2299
2379
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
2300
2380
  # cconv, sconv are [Nbatch, Npix_j3, Norient3]
2301
- c01 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(MconvPsi) # [Nbatch, Npix_j3, Norient3, Norient2]
2381
+ if self.use_1D:
2382
+ c01 = conv * self.backend.bk_conjugate(MconvPsi)
2383
+ else:
2384
+ c01 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(MconvPsi) # [Nbatch, Npix_j3, Norient3, Norient2]
2302
2385
 
2303
2386
  ### Apply the mask [Nmask, Npix_j3] and sum over pixels
2304
2387
  if return_data:
@@ -2327,7 +2410,10 @@ class funct(FOC.FoCUS):
2327
2410
 
2328
2411
  ### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
2329
2412
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
2330
- c11 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(self.backend.bk_expand_dims(M2, -1)) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
2413
+ if self.use_1D:
2414
+ c11 = M1 * self.backend.bk_conjugate(M2)
2415
+ else:
2416
+ c11 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(self.backend.bk_expand_dims(M2, -1)) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
2331
2417
 
2332
2418
  ### Apply the mask and sum over pixels
2333
2419
  if return_data:
@@ -1203,7 +1203,9 @@ class funct(FOC.FoCUS):
1203
1203
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
1204
1204
  p00 = conv1 * self.backend.bk_conjugate(conv2)
1205
1205
  # Apply the mask [Nmask, Npix_j3] and average over pixels
1206
+ p00 = self.backend.bk_real(p00)
1206
1207
  p00 = self.backend.bk_reduce_sum(p00*vmask, axis=1)
1208
+ print(p00.shape)
1207
1209
  tmp = self.backend.bk_L1(p00) # [Nbatch, Npix_j3, Norient3]
1208
1210
 
1209
1211
  ### Normalize P00_cross
@@ -1215,9 +1217,9 @@ class funct(FOC.FoCUS):
1215
1217
  p00=self.backend.bk_real(p00)
1216
1218
 
1217
1219
  if P00 is None:
1218
- P00 = p00[:,:,None,:] # Add a dimension for NP00
1220
+ P00 = p00[:,:,None] # Add a dimension for NP00
1219
1221
  else:
1220
- P00 = self.backend.bk_concat([P00, p00[:,:,None,:]], axis=2)
1222
+ P00 = self.backend.bk_concat([P00, p00[:,:,None]], axis=axis+2)
1221
1223
 
1222
1224
  #### S1_auto computation
1223
1225
  ### Image 1 : S1 = < M1 >_pix
@@ -0,0 +1,16 @@
1
+ import foscat.scat_cov as scat  # NOTE: `scat` is bound to the foscat.scat_cov module, not to a class
2
+
3
+ class scat_cov1D:
4
+ def __init__(self,p00,s0,s1,s2,s2l,j1,j2,cross=False,backend=None):  # NOTE(review): none of these parameters are used in the body below
5
+
6
+ the_scat=scat(P00, C01, C11, s1=S1, c10=C10,backend=self.backend)  # NOTE(review): P00/C01/C11/S1/C10 are never defined in this file (the parameters have different names), so this raises NameError; self.backend is never assigned and `scat` is a module, not a callable
7
+ the_scat.set_bk_type('SCAT_COV1D')
8
+ return the_scat  # NOTE(review): __init__ must return None; returning an object here raises TypeError at construction time
9
+
10
+ def fill(self,im,nullval=0):  # NOTE(review): indentation is lost in this view; confirm which class this method belongs to
11
+ return self.fill_1d(im,nullval=nullval)  # NOTE(review): fill_1d is not defined in this file; confirm the owning class provides it
12
+
13
+ class funct(scat.funct):  # scat_cov.funct specialized to always run with use_1D=True
14
+ def __init__(self, *args, **kwargs):
15
+ # Always pass use_1D=True to the parent scat.funct constructor; a caller-supplied use_1D keyword would raise TypeError (multiple values for keyword argument)
16
+ super(funct, self).__init__(use_1D=True, *args, **kwargs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: foscat
3
- Version: 3.0.47
3
+ Version: 3.1.1
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Home-page: https://github.com/jmdelouis/FOSCAT
6
6
  Author: Jean-Marc DELOUIS
@@ -18,6 +18,7 @@ src/foscat/scat.py
18
18
  src/foscat/scat1D.py
19
19
  src/foscat/scat2D.py
20
20
  src/foscat/scat_cov.py
21
+ src/foscat/scat_cov1D.old.py
21
22
  src/foscat/scat_cov1D.py
22
23
  src/foscat/scat_cov2D.py
23
24
  src/foscat/scat_cov_map.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes