foscat 3.0.47__tar.gz → 3.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. {foscat-3.0.47 → foscat-3.1.0}/PKG-INFO +1 -1
  2. {foscat-3.0.47 → foscat-3.1.0}/setup.py +1 -1
  3. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/FoCUS.py +286 -47
  4. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/backend.py +24 -1
  5. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/scat_cov.py +136 -50
  6. foscat-3.0.47/src/foscat/scat_cov1D.py → foscat-3.1.0/src/foscat/scat_cov1D.old.py +4 -2
  7. foscat-3.1.0/src/foscat/scat_cov1D.py +16 -0
  8. {foscat-3.0.47 → foscat-3.1.0}/src/foscat.egg-info/PKG-INFO +1 -1
  9. {foscat-3.0.47 → foscat-3.1.0}/src/foscat.egg-info/SOURCES.txt +1 -0
  10. {foscat-3.0.47 → foscat-3.1.0}/README.md +0 -0
  11. {foscat-3.0.47 → foscat-3.1.0}/setup.cfg +0 -0
  12. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/CNN.py +0 -0
  13. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/CircSpline.py +0 -0
  14. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/GCNN.py +0 -0
  15. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/GetGPUinfo.py +0 -0
  16. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/Softmax.py +0 -0
  17. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/Spline1D.py +0 -0
  18. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/Synthesis.py +0 -0
  19. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/__init__.py +0 -0
  20. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/backend_tens.py +0 -0
  21. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/loss_backend_tens.py +0 -0
  22. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/loss_backend_torch.py +0 -0
  23. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/scat.py +0 -0
  24. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/scat1D.py +0 -0
  25. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/scat2D.py +0 -0
  26. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/scat_cov2D.py +0 -0
  27. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/scat_cov_map.py +0 -0
  28. {foscat-3.0.47 → foscat-3.1.0}/src/foscat/scat_cov_map2D.py +0 -0
  29. {foscat-3.0.47 → foscat-3.1.0}/src/foscat.egg-info/dependency_links.txt +0 -0
  30. {foscat-3.0.47 → foscat-3.1.0}/src/foscat.egg-info/requires.txt +0 -0
  31. {foscat-3.0.47 → foscat-3.1.0}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: foscat
3
- Version: 3.0.47
3
+ Version: 3.1.0
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Home-page: https://github.com/jmdelouis/FOSCAT
6
6
  Author: Jean-Marc DELOUIS
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
3
3
 
4
4
  setup(
5
5
  name='foscat',
6
- version='3.0.47',
6
+ version='3.1.0',
7
7
  description='Generate synthetic Healpix or 2D data using Cross Scattering Transform' ,
8
8
  long_description='Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising. \n A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO. \n Complete doc can be found at https://foscat-documentation.readthedocs.io/en/latest/index.html. \n\n List of developers : J.-M. Delouis, T. Foulquier, L. Mousset, T. Odaka, F. Paul, E. Allys ' ,
9
9
  license='MIT',
@@ -24,6 +24,7 @@ class FoCUS:
24
24
  TEMPLATE_PATH='data',
25
25
  BACKEND='tensorflow',
26
26
  use_2D=False,
27
+ use_1D=False,
27
28
  return_data=False,
28
29
  JmaxDelta=0,
29
30
  DODIV=False,
@@ -32,7 +33,7 @@ class FoCUS:
32
33
  mpi_size=1,
33
34
  mpi_rank=0):
34
35
 
35
- self.__version__ = '3.0.47'
36
+ self.__version__ = '3.1.0'
36
37
  # P00 coeff for normalization for scat_cov
37
38
  self.TMPFILE_VERSION=TMPFILE_VERSION
38
39
  self.P1_dic = None
@@ -86,6 +87,7 @@ class FoCUS:
86
87
 
87
88
  self.OSTEP=JmaxDelta
88
89
  self.use_2D=use_2D
90
+ self.use_1D=use_1D
89
91
 
90
92
  if isMPI:
91
93
  from mpi4py import MPI
@@ -214,12 +216,14 @@ class FoCUS:
214
216
 
215
217
 
216
218
  w_smooth=w_smooth.flatten()
217
-
219
+ if self.use_1D:
220
+ KERNELSZ=5
221
+
218
222
  self.KERNELSZ=KERNELSZ
219
223
 
220
224
  self.Idx_Neighbours={}
221
225
 
222
- if not self.use_2D:
226
+ if not self.use_2D and not self.use_1D:
223
227
  self.w_smooth = {}
224
228
  for i in range(nstep_max):
225
229
  lout=(2**i)
@@ -239,6 +243,21 @@ class FoCUS:
239
243
  self.ww_Real[lout]=wr
240
244
  self.ww_Imag[lout]=wi
241
245
  self.w_smooth[lout]=ws
246
+ elif self.use_1D==True:
247
+ self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
248
+ self.ww_RealT={}
249
+ self.ww_ImagT={}
250
+ self.ww_SmoothT={}
251
+ if KERNELSZ==5:
252
+ xx=np.arange(5)-2
253
+ w=np.exp(-0.25*(xx)**2)
254
+ c=np.cos((xx)*np.pi/2)
255
+ s=np.sin((xx)*np.pi/2)
256
+
257
+ self.ww_RealT[1]=self.backend.constant(np.array(w*c).reshape(xx.shape[0],1,1))
258
+ self.ww_ImagT[1]=self.backend.constant(np.array(w*s).reshape(xx.shape[0],1,1))
259
+ self.ww_SmoothT[1] = self.backend.constant(np.array(w).reshape(xx.shape[0],1,1))
260
+
242
261
  else:
243
262
  self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
244
263
  self.ww_RealT={}
@@ -573,7 +592,7 @@ class FoCUS:
573
592
  tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
574
593
  tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),0:2*(npiy//2),:],[ndata,npix//2,2,npiy//2,2,odata])
575
594
 
576
- res=self.backend.bk_reduce_mean(self.backend.bk_reduce_mean(tim,4),2)
595
+ res=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim,4),2)/4
577
596
 
578
597
  if axis==0:
579
598
  if len(ishape)==2:
@@ -587,6 +606,40 @@ class FoCUS:
587
606
  return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2]+ishape[axis+2:])
588
607
 
589
608
  return self.backend.bk_reshape(res,[npix//2,npiy//2])
609
+ elif self.use_1D:
610
+ ishape=list(im.shape)
611
+ if len(ishape)<axis+1:
612
+ if not self.silent:
613
+ print('Use of 1D scat with data that has less than 1D')
614
+ return None
615
+
616
+ npix=im.shape[axis]
617
+ odata=1
618
+ if len(ishape)>axis+1:
619
+ for k in range(axis+1,len(ishape)):
620
+ odata=odata*ishape[k]
621
+
622
+ ndata=1
623
+ for k in range(axis):
624
+ ndata=ndata*ishape[k]
625
+
626
+ tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
627
+ tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),:],[ndata,npix//2,2,odata])
628
+
629
+ res=self.backend.bk_reduce_mean(tim,2)
630
+
631
+ if axis==0:
632
+ if len(ishape)==1:
633
+ return self.backend.bk_reshape(res,[npix//2])
634
+ else:
635
+ return self.backend.bk_reshape(res,[npix//2]+ishape[axis+1:])
636
+ else:
637
+ if len(ishape)==axis+1:
638
+ return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2])
639
+ else:
640
+ return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2]+ishape[axis+1:])
641
+
642
+ return self.backend.bk_reshape(res,[npix//2])
590
643
 
591
644
  else:
592
645
  shape=list(im.shape)
@@ -653,6 +706,46 @@ class FoCUS:
653
706
  return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty]+ishape[axis+2:])
654
707
 
655
708
  return self.backend.bk_reshape(res,[nout,nouty])
709
+
710
+ elif self.use_1D:
711
+ ishape=list(im.shape)
712
+ if len(ishape)<axis+1:
713
+ if not self.silent:
714
+ print('Use of 1D scat with data that has less than 1D')
715
+ return None
716
+
717
+ if ishape[axis]==nout:
718
+ return im
719
+
720
+ npix=im.shape[axis]
721
+ odata=1
722
+ if len(ishape)>axis+1:
723
+ for k in range(axis+1,len(ishape)):
724
+ odata=odata*ishape[k]
725
+
726
+ ndata=1
727
+ for k in range(axis):
728
+ ndata=ndata*ishape[k]
729
+
730
+ tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
731
+
732
+ while tim.shape[1]!=nout:
733
+ res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
734
+ res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
735
+ tim = self.backend.bk_reshape(self.backend.bk_concat([res1,res2],-2),[ndata,tim.shape[1]*2,odata])
736
+
737
+ if axis==0:
738
+ if len(ishape)==1:
739
+ return self.backend.bk_reshape(tim,[nout])
740
+ else:
741
+ return self.backend.bk_reshape(tim,[nout]+ishape[axis+1:])
742
+ else:
743
+ if len(ishape)==axis+1:
744
+ return self.backend.bk_reshape(tim,ishape[0:axis]+[nout])
745
+ else:
746
+ return self.backend.bk_reshape(tim,ishape[0:axis]+[nout]+ishape[axis+1:])
747
+
748
+ return self.backend.bk_reshape(tim,[nout])
656
749
 
657
750
  else:
658
751
 
@@ -842,7 +935,7 @@ class FoCUS:
842
935
  ndata=ndata*ishape[k]
843
936
 
844
937
  tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
845
-
938
+
846
939
  res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
847
940
  res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
848
941
  res = self.backend.bk_concat([res1,res2],-2)
@@ -1278,66 +1371,117 @@ class FoCUS:
1278
1371
  sum_mask=self.backend.bk_reduce_sum(self.backend.bk_reshape(l_mask,[l_mask.shape[0],np.prod(np.array(l_mask.shape[1:]))]),1)
1279
1372
  if not self.use_2D:
1280
1373
  l_mask=12*nside*nside*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1281
- else:
1374
+ elif self.use_2D:
1282
1375
  l_mask=mask.shape[1]*mask.shape[2]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1283
-
1376
+ else:
1377
+ l_mask=mask.shape[1]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1378
+
1284
1379
  if self.use_2D:
1285
- if self.padding=='VALID' and shape[axis]!=l_mask.shape[1]:
1380
+ if self.padding=='VALID':
1286
1381
  l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1287
1382
  if shape[axis]!=l_mask.shape[1]:
1288
1383
  l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1289
- else:
1290
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1291
1384
 
1292
- # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,...,KERNELSZ//2:-self.KERNELSZ//2,KERNELSZ//2:-self.KERNELSZ//2,NORIENT[,NORIENT]]
1293
- if self.use_2D:
1294
1385
  ichannel=1
1295
1386
  for i in range(axis):
1296
1387
  ichannel*=shape[i]
1297
1388
  ochannel=1
1298
1389
  for i in range(axis+2,len(shape)):
1299
1390
  ochannel*=shape[i]
1300
- l_x=self.backend.bk_reshape(x,[ichannel,shape[axis],shape[axis+1],ochannel])
1301
- oshape=[k for k in shape]
1302
- oshape[axis]=oshape[axis]-self.KERNELSZ+1
1303
- oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
1304
- l_x=self.backend.bk_reshape(l_x[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1305
- else:
1306
- l_x=x
1391
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],shape[axis+1],ochannel])
1392
+
1393
+ if self.padding=='VALID':
1394
+ oshape=[k for k in shape]
1395
+ oshape[axis]=oshape[axis]-self.KERNELSZ+1
1396
+ oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
1397
+ l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1398
+
1399
+ elif self.use_1D:
1400
+ if self.padding=='VALID':
1401
+ l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1402
+ if shape[axis]!=l_mask.shape[1]:
1403
+ l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1404
+
1405
+ ichannel=1
1406
+ for i in range(axis):
1407
+ ichannel*=shape[i]
1408
+ ochannel=1
1409
+ for i in range(axis+1,len(shape)):
1410
+ ochannel*=shape[i]
1411
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1307
1412
 
1413
+ if self.padding=='VALID':
1414
+ oshape=[k for k in shape]
1415
+ oshape[axis]=oshape[axis]-self.KERNELSZ+1
1416
+ l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1417
+ else:
1418
+ ichannel=1
1419
+ for i in range(axis):
1420
+ ichannel*=shape[i]
1421
+ ochannel=1
1422
+ for i in range(axis+1,len(shape)):
1423
+ ochannel*=shape[i]
1424
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1425
+
1308
1426
  # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,1,...,X[,Y],NORIENT[,NORIENT]]
1309
- l_x=self.backend.bk_expand_dims(l_x,1)
1310
-
1311
1427
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,X[,Y]]
1312
1428
  l_mask=self.backend.bk_expand_dims(l_mask,0)
1313
-
1314
- # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
1315
- for i in range(1,axis):
1316
- l_mask=self.backend.bk_expand_dims(l_mask,axis)
1429
+ # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,X[,Y],1]
1430
+ l_mask=self.backend.bk_expand_dims(l_mask,-1)
1317
1431
 
1318
1432
  if l_x.dtype==self.all_cbk_type:
1319
1433
  l_mask=self.backend.bk_complex(l_mask,self.backend.bk_cast(0.0*l_mask))
1320
1434
 
1321
1435
  if self.use_2D:
1436
+ mtmp=l_mask
1437
+ vtmp=l_x
1438
+
1439
+ v1=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp,axis=2),2)
1440
+ v2=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2),2)
1441
+ vh=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp,axis=2),2)
1322
1442
 
1323
- # mask=[1,Nmask,....,X,Y] => mask=[1,Nmask,....,X,Y,....]
1324
- for i in range(axis+2,len(x.shape)):
1325
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1326
-
1327
- shape1=list(l_mask.shape)
1328
- shape2=list(l_x.shape)
1443
+ res=v1/vh
1444
+
1445
+ oshape=[]
1446
+ if axis>0:
1447
+ oshape=oshape+list(x.shape[0:axis])
1448
+ oshape=oshape+[mask.shape[0]]
1449
+ if axis+1<len(x.shape):
1450
+ oshape=oshape+list(x.shape[axis+2:])
1451
+
1452
+ if calc_var:
1453
+ if self.backend.bk_is_complex(vtmp):
1454
+ res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1455
+ -self.backend.bk_real(res)*self.backend.bk_real(res)) + \
1456
+ (self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
1457
+ -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1458
+ else:
1459
+ res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1329
1460
 
1330
- oshape1=shape1[0:axis+1]+[shape1[axis+1]*shape1[axis+2]]+shape1[axis+3:]
1331
- oshape2=shape2[0:axis+1]+[shape2[axis+1]*shape2[axis+2]]+shape2[axis+3:]
1461
+ res=self.backend.bk_reshape(res,oshape)
1462
+ res2=self.backend.bk_reshape(res2,oshape)
1463
+ return res,res2
1464
+ else:
1465
+ res=self.backend.bk_reshape(res,oshape)
1466
+ return res
1332
1467
 
1333
- mtmp=self.backend.bk_reshape(l_mask,oshape1)
1334
- vtmp=self.backend.bk_reshape(l_x,oshape2)
1468
+ elif self.use_1D:
1469
+ mtmp=l_mask
1470
+ vtmp=l_x
1335
1471
 
1336
- v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=axis+1)
1337
- v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=axis+1)
1338
- vh=self.backend.bk_reduce_sum(mtmp,axis=axis+1)
1472
+ v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=2)
1473
+ v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2)
1474
+ vh=self.backend.bk_reduce_sum(mtmp,axis=2)
1339
1475
 
1340
1476
  res=v1/vh
1477
+
1478
+ oshape=[]
1479
+ if axis>0:
1480
+ oshape=oshape+list(x.shape[0:axis])
1481
+ oshape=oshape+[mask.shape[0]]
1482
+ if axis+1<len(x.shape):
1483
+ oshape=oshape+list(x.shape[axis+1:])
1484
+
1341
1485
  if calc_var:
1342
1486
  if self.backend.bk_is_complex(vtmp):
1343
1487
  res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
@@ -1346,19 +1490,29 @@ class FoCUS:
1346
1490
  -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1347
1491
  else:
1348
1492
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1493
+
1494
+
1495
+ res=self.backend.bk_reshape(res,oshape)
1496
+ res2=self.backend.bk_reshape(res2,oshape)
1349
1497
  return res,res2
1350
1498
  else:
1499
+ res=self.backend.bk_reshape(res,oshape)
1351
1500
  return res
1352
- else:
1353
- # mask=[1,Nmask,....,X] => mask=[1,Nmask,....,X,....]
1354
- for i in range(axis+1,len(x.shape)):
1355
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1356
-
1357
- v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=axis+1)
1358
- v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=axis+1)
1359
- vh=self.backend.bk_reduce_sum(l_mask,axis=axis+1)
1501
+
1502
+ else:
1503
+ v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=2)
1504
+ v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=2)
1505
+ vh=self.backend.bk_reduce_sum(l_mask,axis=2)
1360
1506
 
1361
1507
  res=v1/vh
1508
+
1509
+ oshape=[]
1510
+ if axis>0:
1511
+ oshape=oshape+list(x.shape[0:axis])
1512
+ oshape=oshape+[mask.shape[0]]
1513
+ if axis+1<len(x.shape):
1514
+ oshape=oshape+list(x.shape[axis+1:])
1515
+
1362
1516
  if calc_var:
1363
1517
  if self.backend.bk_is_complex(l_x):
1364
1518
  res2=self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
@@ -1367,8 +1521,12 @@ class FoCUS:
1367
1521
  -self.backend.bk_imag(res)*self.backend.bk_imag(res))/self.backend.bk_real(vh))
1368
1522
  else:
1369
1523
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1524
+
1525
+ res=self.backend.bk_reshape(res,oshape)
1526
+ res2=self.backend.bk_reshape(res2,oshape)
1370
1527
  return res,res2
1371
1528
  else:
1529
+ res=self.backend.bk_reshape(res,oshape)
1372
1530
  return res
1373
1531
 
1374
1532
  # -------------------------------------------------------
@@ -1485,7 +1643,6 @@ class FoCUS:
1485
1643
  image=self.backend.bk_cast(in_image)
1486
1644
 
1487
1645
  if self.use_2D:
1488
-
1489
1646
  ishape=list(in_image.shape)
1490
1647
  if len(ishape)<axis+2:
1491
1648
  if not self.silent:
@@ -1528,6 +1685,49 @@ class FoCUS:
1528
1685
  return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
1529
1686
 
1530
1687
  return self.backend.bk_reshape(res,[nout,nouty])
1688
+ elif self.use_1D==True:
1689
+ ishape=list(in_image.shape)
1690
+ if len(ishape)<axis+1:
1691
+ if not self.silent:
1692
+ print('Use of 1D scat with data that has less than 1D')
1693
+ return None
1694
+
1695
+ npix=ishape[axis]
1696
+ odata=1
1697
+ if len(ishape)>axis+1:
1698
+ for k in range(axis+1,len(ishape)):
1699
+ odata=odata*ishape[k]
1700
+
1701
+ ndata=1
1702
+ for k in range(axis):
1703
+ ndata=ndata*ishape[k]
1704
+
1705
+ tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
1706
+
1707
+ if self.backend.bk_is_complex(tim):
1708
+ rr1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1709
+ ii1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1710
+ rr2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1711
+ ii2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1712
+ res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
1713
+ else:
1714
+ rr=self.backend.conv1d(tim,self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1715
+ ii=self.backend.conv1d(tim,self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1716
+ res=self.backend.bk_complex(rr,ii)
1717
+
1718
+ if axis==0:
1719
+ if len(ishape)==1:
1720
+ return self.backend.bk_reshape(res,[res.shape[1]])
1721
+ else:
1722
+ return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+2:])
1723
+ else:
1724
+ if len(ishape)==axis+1:
1725
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
1726
+ else:
1727
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1728
+
1729
+ return self.backend.bk_reshape(res,[nout,nouty])
1730
+
1531
1731
 
1532
1732
  else:
1533
1733
  nside=int(np.sqrt(image.shape[axis]//12))
@@ -1649,7 +1849,46 @@ class FoCUS:
1649
1849
  return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]]+ishape[axis+2:])
1650
1850
 
1651
1851
  return self.backend.bk_reshape(res,[nout,nouty])
1852
+ elif self.use_1D:
1853
+
1854
+ ishape=list(in_image.shape)
1855
+ if len(ishape)<axis+1:
1856
+ if not self.silent:
1857
+ print('Use of 1D scat with data that has less than 1D')
1858
+ return None
1652
1859
 
1860
+ npix=ishape[axis]
1861
+ odata=1
1862
+ if len(ishape)>axis+1:
1863
+ for k in range(axis+1,len(ishape)):
1864
+ odata=odata*ishape[k]
1865
+
1866
+ ndata=1
1867
+ for k in range(axis):
1868
+ ndata=ndata*ishape[k]
1869
+
1870
+ tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
1871
+
1872
+ if self.backend.bk_is_complex(tim):
1873
+ rr=self.backend.conv1d(self.backend.bk_real(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1874
+ ii=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1875
+ res=self.backend.bk_complex(rr,ii)
1876
+ else:
1877
+ res=self.backend.conv1d(tim,self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1878
+
1879
+ if axis==0:
1880
+ if len(ishape)==1:
1881
+ return self.backend.bk_reshape(res,[res.shape[1]])
1882
+ else:
1883
+ return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+1:])
1884
+ else:
1885
+ if len(ishape)==axis+1:
1886
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
1887
+ else:
1888
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1889
+
1890
+ return self.backend.bk_reshape(res,[nout,nouty])
1891
+
1653
1892
  else:
1654
1893
  nside=int(np.sqrt(image.shape[axis]//12))
1655
1894
 
@@ -289,10 +289,33 @@ class foscat_backend:
289
289
  for k in range(w.shape[2]):
290
290
  for l in range(w.shape[3]):
291
291
  for j in range(res.shape[0]):
292
- tmp=self.scipy.signal.convolve2d(x[j,:,:,k],w[:,:,k,l], mode='same', boundary='fill', fillvalue=0.0)
292
+ tmp=self.scipy.signal.convolve2d(x[j,:,:,k],w[:,:,k,l], mode='same', boundary='symm')
293
293
  res[j,:,:,l]+=tmp
294
294
  del tmp
295
295
  return res
296
+
297
def conv1d(self,x,w,strides=[1, 1, 1],padding='SAME'):
    """Convolve a batch of 1D signals with a bank of 1D kernels.

    Parameters
    ----------
    x : array or tensor indexed as [batch, position, input channel].
    w : kernel bank indexed as [tap, input channel, output channel].
    strides : forwarded as ``stride`` to the TensorFlow conv1d call;
        unused by the other branches.
    padding : accepted for signature parity with the callers but not
        consulted: the TensorFlow and NumPy branches always extend the
        signal symmetrically along axis 1 so the output keeps the input
        length.

    Returns
    -------
    TensorFlow / NumPy: result indexed as [batch, position, output channel].
    Torch: ``x`` unchanged (placeholder, see note in the body).
    Any other backend value falls through and returns None.
    """
    if self.BACKEND==self.TENSORFLOW:
        kx=w.shape[0]
        # Pad only the position axis, then run a VALID convolution so the
        # boundary handling is symmetric instead of zero-filled.
        paddings = self.backend.constant([[0,0],
                                          [kx//2,kx//2],
                                          [0,0]])
        tmp=self.backend.pad(x, paddings, "SYMMETRIC")

        return self.backend.nn.conv1d(tmp,w,
                                      stride=strides,
                                      padding="VALID")

    # NOTE(review): the Torch path is still a placeholder ("to be written")
    # and hands the input back untouched -- callers get no convolution.
    if self.BACKEND==self.TORCH:
        return x

    if self.BACKEND==self.NUMPY:
        kx=w.shape[0]
        # Left/right pad widths chosen so that a 'valid' convolution of the
        # padded signal lines up with a centred 'same' convolution.
        left=kx//2
        right=kx-1-left
        res=np.zeros([x.shape[0],x.shape[1],w.shape[2]],dtype=x.dtype)
        # Accumulate every (input channel k, output channel l) pair, as the
        # 2D NumPy convolution of this backend does with convolve2d.
        # NOTE(review): this is a true convolution (kernel flipped), like the
        # convolve2d-based 2D path; confirm whether it should instead match
        # the TensorFlow branch's kernel orientation.
        for k in range(w.shape[1]):
            for l in range(w.shape[2]):
                for j in range(res.shape[0]):
                    padded=np.pad(x[j,:,k],(left,right),mode='symmetric')
                    res[j,:,l]+=np.convolve(padded,w[:,k,l],mode='valid')
        return res
296
319
 
297
320
  def bk_threshold(self,x,threshold,greater=True):
298
321
 
@@ -859,6 +859,30 @@ class scat_cov:
859
859
  tab2nx=tab2nx+['%d'%(i2)]
860
860
  ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
861
861
  n=n+j2[j1==i2].shape[0]-1
862
+ elif len(tmp.shape)==3:
863
+ for i0 in range(tmp.shape[0]):
864
+ for i1 in range(tmp.shape[1]):
865
+ for i2 in range(j1.max()+1):
866
+ dtmp=tmp[i0,i1,j1==i2]
867
+ if norm:
868
+ dtmp=dtmp/(ntmp[i0,i1,i2]*ntmp[i0,i1,j2[j1==i2]])
869
+ if j2[j1==i2].shape[0]==1:
870
+ ax1.plot(j2[j1==i2]+n,dtmp,'.', \
871
+ color=color, lw=lw)
872
+ else:
873
+ if legend and test is None:
874
+ ax1.plot(j2[j1==i2]+n,dtmp, \
875
+ color=color, label=lname, lw=lw)
876
+ test=1
877
+ ax1.plot(j2[j1==i2]+n,dtmp, \
878
+ color=color, lw=lw)
879
+ tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
880
+ tabx=tabx+[k+n for k in j2[j1==i2]]
881
+ tab2x=tab2x+[(j2[j1==i2]+n).mean()]
882
+ tab2nx=tab2nx+['%d'%(i2)]
883
+ ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
884
+ n=n+j2[j1==i2].shape[0]-1
885
+
862
886
  else:
863
887
  for i0 in range(tmp.shape[0]):
864
888
  for i1 in range(tmp.shape[1]):
@@ -951,6 +975,32 @@ class scat_cov:
951
975
  tab2x=tab2x+[(n+nprev-1)/2]
952
976
  tab2nx=tab2nx+['%d'%(i2)]
953
977
  ax1.axvline(n-0.5,ls=':',color='gray')
978
+ elif len(tmp.shape)==3:
979
+ for i0 in range(tmp.shape[0]):
980
+ for i1 in range(tmp.shape[1]):
981
+ for i2 in range(j1.max()+1):
982
+ nprev=n
983
+ for i2b in range(j2[j1==i2].max()+1):
984
+ idx=np.where((j1==i2)*(j2==i2b))[0]
985
+ dtmp=tmp[i0,i1,idx]
986
+ if norm:
987
+ dtmp=dtmp/(ntmp[i0,i1,i2]*ntmp[i0,i1,i2b])
988
+ if len(idx)==1:
989
+ ax1.plot(np.arange(len(idx))+n,dtmp,'.', \
990
+ color=color, lw=lw)
991
+ else:
992
+ if legend and test is None:
993
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
994
+ color=color, label=lname, lw=lw)
995
+ test=1
996
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
997
+ color=color, lw=lw)
998
+ tabnx=tabnx+[r'%d,%d'%(j2[k],j3[k]) for k in idx]
999
+ tabx=tabx+[k+n for k in range(len(idx))]
1000
+ n=n+idx.shape[0]
1001
+ tab2x=tab2x+[(n+nprev-1)/2]
1002
+ tab2nx=tab2nx+['%d'%(i2)]
1003
+ ax1.axvline(n-0.5,ls=':',color='gray')
954
1004
  else:
955
1005
  for i0 in range(tmp.shape[0]):
956
1006
  for i1 in range(tmp.shape[1]):
@@ -1495,6 +1545,8 @@ class funct(FOC.FoCUS):
1495
1545
  def fill(self,im,nullval=hp.UNSEEN):
1496
1546
  if self.use_2D:
1497
1547
  return self.fill_2d(im,nullval=nullval)
1548
+ if self.use_1D:
1549
+ return self.fill_1d(im,nullval=nullval)
1498
1550
  return self.fill_healpy(im,nullval=nullval)
1499
1551
 
1500
1552
  def moments(self,list_scat):
@@ -1748,6 +1800,15 @@ class funct(FOC.FoCUS):
1748
1800
  x1=im_shape[1]
1749
1801
  x2=im_shape[2]
1750
1802
  J = int(np.log(nside-self.KERNELSZ) / np.log(2)) # Number of j scales
1803
+ elif self.use_1D:
1804
+ if len(image1.shape)==2:
1805
+ npix = int(im_shape[1]) # Number of pixels
1806
+ else:
1807
+ npix = int(im_shape[0]) # Number of pixels
1808
+
1809
+ nside=int(npix)
1810
+
1811
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
1751
1812
  else:
1752
1813
  if len(image1.shape)==2:
1753
1814
  npix = int(im_shape[1]) # Number of pixels
@@ -1785,6 +1846,11 @@ class funct(FOC.FoCUS):
1785
1846
  I1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1786
1847
  if cross:
1787
1848
  I2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1849
+ elif self.use_1D:
1850
+ vmask=self.up_grade(vmask,I1.shape[axis]*2,axis=1)
1851
+ I1=self.up_grade(I1,I1.shape[axis]*2,axis=axis)
1852
+ if cross:
1853
+ I2=self.up_grade(I2,I2.shape[axis]*2,axis=axis)
1788
1854
  else:
1789
1855
  I1 = self.up_grade(I1, nside * 2, axis=axis)
1790
1856
  vmask = self.up_grade(vmask, nside * 2, axis=1)
@@ -1798,6 +1864,11 @@ class funct(FOC.FoCUS):
1798
1864
  I1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1799
1865
  if cross:
1800
1866
  I2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1867
+ elif self.use_1D:
1868
+ vmask=self.up_grade(vmask,I1.shape[axis]*4,axis=1)
1869
+ I1=self.up_grade(I1,I1.shape[axis]*4,axis=axis)
1870
+ if cross:
1871
+ I2=self.up_grade(I2,I2.shape[axis]*4,axis=axis)
1801
1872
  else:
1802
1873
  I1 = self.up_grade(I1, nside * 4, axis=axis)
1803
1874
  vmask = self.up_grade(vmask, nside * 4, axis=1)
@@ -1811,6 +1882,14 @@ class funct(FOC.FoCUS):
1811
1882
  # Coefficients
1812
1883
  S1, P00, C01, C11, C10 = None, None, None, None, None
1813
1884
 
1885
+ off_P0=-2
1886
+ off_C01=-3
1887
+ off_C11=-4
1888
+ if self.use_1D:
1889
+ off_P0=-1
1890
+ off_C01=-1
1891
+ off_C11=-1
1892
+
1814
1893
  # Dictionaries for C01 computation
1815
1894
  M1_dic = {} # M stands for Module M1 = |I1 * Psi|
1816
1895
  if cross:
@@ -1842,7 +1921,6 @@ class funct(FOC.FoCUS):
1842
1921
  s0 = self.masked_mean(I1,vmask,axis=1)
1843
1922
  else:
1844
1923
  s0 = self.masked_mean(I1-I2,vmask,axis=1)
1845
-
1846
1924
 
1847
1925
  #### COMPUTE S1, P00, C01 and C11
1848
1926
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
@@ -1887,6 +1965,7 @@ class funct(FOC.FoCUS):
1887
1965
  else:
1888
1966
  p00 = self.masked_mean(M1_square, vmask, axis=1,rank=j3)
1889
1967
 
1968
+
1890
1969
  if cond_init_P1_dic:
1891
1970
  # We fill P1_dic with P00 for normalisation of C01 and C11
1892
1971
  P1_dic[j3] = p00 # [Nbatch, Nmask, Norient3]
@@ -1900,13 +1979,13 @@ class funct(FOC.FoCUS):
1900
1979
  if norm == 'auto': # Normalize P00
1901
1980
  p00 /= P1_dic[j3]
1902
1981
  if P00 is None:
1903
- P00 = p00[:, :, None, :] # Add a dimension for NP00
1982
+ P00 = self.backend.bk_expand_dims(p00,off_P0) # Add a dimension for NP00
1904
1983
  if calc_var:
1905
- VP00 = vp00[:, :, None, :] # Add a dimension for NP00
1984
+ VP00 = self.backend.bk_expand_dims(vp00,off_P0) # Add a dimension for NP00
1906
1985
  else:
1907
- P00 = self.backend.bk_concat([P00, p00[:, :, None, :]], axis=2)
1986
+ P00 = self.backend.bk_concat([P00, self.backend.bk_expand_dims(p00,off_P0)], axis=2)
1908
1987
  if calc_var:
1909
- VP00 = self.backend.bk_concat([VP00, vp00[:, :, None, :]], axis=2)
1988
+ VP00 = self.backend.bk_concat([VP00, self.backend.bk_expand_dims(vp00,off_P0)], axis=2)
1910
1989
 
1911
1990
  #### S1_auto computation
1912
1991
  ### Image 1 : S1 = < M1 >_pix
@@ -1929,13 +2008,13 @@ class funct(FOC.FoCUS):
1929
2008
  self.div_norm(s1,(P1_dic[j3]) ** 0.5)
1930
2009
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
1931
2010
  if S1 is None:
1932
- S1 = s1[:, :, None, :] # Add a dimension for NS1
2011
+ S1 = self.backend.bk_expand_dims(s1,off_P0) # Add a dimension for NS1
1933
2012
  if calc_var:
1934
- VS1 = vs1[:, :, None, :] # Add a dimension for NS1
2013
+ VS1 = self.backend.bk_expand_dims(vs1,off_P0) # Add a dimension for NS1
1935
2014
  else:
1936
- S1 = self.backend.bk_concat([S1, s1[:, :, None, :]], axis=2)
2015
+ S1 = self.backend.bk_concat([S1,self.backend.bk_expand_dims(s1,off_P0)], axis=2)
1937
2016
  if calc_var:
1938
- VS1 = self.backend.bk_concat([VS1, vs1[:, :, None, :]], axis=2)
2017
+ VS1 = self.backend.bk_concat([VS1, self.backend.bk_expand_dims(vs1,off_P0)], axis=2)
1939
2018
 
1940
2019
  else: # Cross
1941
2020
  ### Make the convolution I2 * Psi_j3
@@ -1994,13 +2073,13 @@ class funct(FOC.FoCUS):
1994
2073
  p00=self.backend.bk_real(p00)
1995
2074
 
1996
2075
  if P00 is None:
1997
- P00 = p00[:,:,None,:] # Add a dimension for NP00
2076
+ P00 = self.backend.bk_expand_dims(p00,off_P0) # Add a dimension for NP00
1998
2077
  if calc_var:
1999
- VP00 = vp00[:,:,None,:] # Add a dimension for NP00
2078
+ VP00 = self.backend.bk_expand_dims(vp00,off_P0) # Add a dimension for NP00
2000
2079
  else:
2001
- P00 = self.backend.bk_concat([P00, p00[:,:,None,:]], axis=2)
2080
+ P00 = self.backend.bk_concat([P00, self.backend.bk_expand_dims(p00,off_P0)], axis=2)
2002
2081
  if calc_var:
2003
- VP00 = self.backend.bk_concat([VP00, vp00[:,:,None,:]], axis=2)
2082
+ VP00 = self.backend.bk_concat([VP00, self.backend.bk_expand_dims(vp00,off_P0)], axis=2)
2004
2083
 
2005
2084
  #### S1_auto computation
2006
2085
  ### Image 1 : S1 = < M1 >_pix
@@ -2022,14 +2101,15 @@ class funct(FOC.FoCUS):
2022
2101
  self.div_norm(s1,(P1_dic[j3]) ** 0.5)
2023
2102
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2024
2103
  if S1 is None:
2025
- S1 = s1[:, :, None, :] # Add a dimension for NS1
2104
+ S1 = self.backend.bk_expand_dims(s1,off_P0) # Add a dimension for NS1
2026
2105
  if calc_var:
2027
- VS1 = vs1[:, :, None, :] # Add a dimension for NS1
2106
+ VS1 = self.backend.bk_expand_dims(vs1,off_P0) # Add a dimension for NS1
2028
2107
  else:
2029
- S1 = self.backend.bk_concat([S1, s1[:, :, None, :]], axis=2)
2108
+ S1 = self.backend.bk_concat([S1, self.backend.bk_expand_dims(s1,off_P0)], axis=2)
2030
2109
  if calc_var:
2031
- VS1 = self.backend.bk_concat([VS1, vs1[:, :, None, :]], axis=2)
2032
-
2110
+ VS1 = self.backend.bk_concat([VS1,
2111
+ self.backend.bk_expand_dims(vs1,off_P0)], axis=2)
2112
+
2033
2113
  # Initialize dictionaries for |I1*Psi_j| * Psi_j3
2034
2114
  M1convPsi_dic = {}
2035
2115
  if cross:
@@ -2067,19 +2147,19 @@ class funct(FOC.FoCUS):
2067
2147
  else:
2068
2148
  ### Normalize C01 with P00_j [Nbatch, Nmask, Norient_j]
2069
2149
  if norm is not None:
2070
- self.div_norm(c01,(P1_dic[j2][:, :, None, :] *
2071
- P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2150
+ self.div_norm(c01,(self.backend.bk_expand_dims(P1_dic[j2],off_P0) *
2151
+ self.backend.bk_expand_dims(P1_dic[j3],-1)) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2072
2152
 
2073
2153
  ### Store C01 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
2074
2154
  if C01 is None:
2075
- C01 = c01[:,:,None,:,:] # Add a dimension for NC01
2155
+ C01 = self.backend.bk_expand_dims(c01,off_C01) # Add a dimension for NC01
2076
2156
  if calc_var:
2077
- VC01 = vc01[:,:,None,:,:] # Add a dimension for NC01
2157
+ VC01 =self.backend.bk_expand_dims(vc01,off_C01) # Add a dimension for NC01
2078
2158
  else:
2079
- C01 = self.backend.bk_concat([C01, c01[:, :, None, :, :]],
2159
+ C01 = self.backend.bk_concat([C01, self.backend.bk_expand_dims(c01,off_C01)],
2080
2160
  axis=2) # Add a dimension for NC01
2081
2161
  if calc_var:
2082
- VC01 = self.backend.bk_concat([VC01, vc01[:, :, None, :, :]],
2162
+ VC01 = self.backend.bk_concat([VC01, self.backend.bk_expand_dims(vc01,off_C01)],
2083
2163
  axis=2) # Add a dimension for NC01
2084
2164
 
2085
2165
  ### C01_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
@@ -2121,28 +2201,28 @@ class funct(FOC.FoCUS):
2121
2201
  else:
2122
2202
  ### Normalize C01 and C10 with P00_j [Nbatch, Nmask, Norient_j]
2123
2203
  if norm is not None:
2124
- self.div_norm(c01,(P2_dic[j2][:, :, None, :] *
2125
- P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2126
- self.div_norm(c10,(P1_dic[j2][:, :, None, :] *
2127
- P2_dic[j3][:, :, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2]
2204
+ self.div_norm(c01,(self.backend.bk_expand_dims(P2_dic[j2],off_P0) *
2205
+ self.backend.bk_expand_dims(P1_dic[j3],-1)) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2206
+ self.div_norm(c10,(self.backend.bk_expand_dims(P1_dic[j2],off_P0) *
2207
+ self.backend.bk_expand_dims(P2_dic[j3],-1)) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2]
2128
2208
 
2129
2209
  ### Store C01 and C10 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
2130
2210
  if C01 is None:
2131
- C01 = c01[:, :, None, :, :] # Add a dimension for NC01
2211
+ C01 = self.backend.bk_expand_dims(c01,off_C01) # Add a dimension for NC01
2132
2212
  if calc_var:
2133
- VC01 = vc01[:, :, None, :, :] # Add a dimension for NC01
2213
+ VC01 = self.backend.bk_expand_dims(vc01,off_C01) # Add a dimension for NC01
2134
2214
  else:
2135
- C01 = self.backend.bk_concat([C01,c01[:, :, None, :, :]],axis=2) # Add a dimension for NC01
2215
+ C01 = self.backend.bk_concat([C01, self.backend.bk_expand_dims(c01,off_C01)],axis=2) # Add a dimension for NC01
2136
2216
  if calc_var:
2137
- VC01 = self.backend.bk_concat([VC01,vc01[:, :, None, :, :]],axis=2) # Add a dimension for NC01
2217
+ VC01 =self.backend.bk_concat([VC01, self.backend.bk_expand_dims(vc01,off_C01)],axis=2) # Add a dimension for NC01
2138
2218
  if C10 is None:
2139
- C10 = c10[:, :, None, :, :] # Add a dimension for NC01
2219
+ C10 = self.backend.bk_expand_dims(c10,off_C01) # Add a dimension for NC01
2140
2220
  if calc_var:
2141
- VC10 = vc10[:, :, None, :, :] # Add a dimension for NC01
2221
+ VC10 = self.backend.bk_expand_dims(vc10,off_C01) # Add a dimension for NC01
2142
2222
  else:
2143
- C10 = self.backend.bk_concat([C10,c10[:, :, None, :, :]], axis=2) # Add a dimension for NC01
2223
+ C10 = self.backend.bk_concat([C10, self.backend.bk_expand_dims(c10,off_C01)], axis=2) # Add a dimension for NC01
2144
2224
  if calc_var:
2145
- VC10 = self.backend.bk_concat([VC10,vc10[:, :, None, :, :]], axis=2) # Add a dimension for NC01
2225
+ VC10 = self.backend.bk_concat([VC10, self.backend.bk_expand_dims(vc10,off_C01)], axis=2) # Add a dimension for NC01
2146
2226
 
2147
2227
 
2148
2228
  ##### C11
@@ -2167,18 +2247,18 @@ class funct(FOC.FoCUS):
2167
2247
  else:
2168
2248
  ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
2169
2249
  if norm is not None:
2170
- self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
2171
- P1_dic[j2][:, :, None, :,None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2250
+ self.div_norm(c11,(self.backend.bk_expand_dims(self.backend.bk_expand_dims(P1_dic[j1],off_P0),off_P0) *
2251
+ self.backend.bk_expand_dims(self.backend.bk_expand_dims(P1_dic[j2],off_P0),-1)) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2172
2252
  ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
2173
2253
  if C11 is None:
2174
- C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
2254
+ C11 = self.backend.bk_expand_dims(c11,off_C11) # Add a dimension for NC11
2175
2255
  if calc_var:
2176
- VC11 = vc11[:, :, None, :, :, :] # Add a dimension for NC11
2256
+ VC11 = self.backend.bk_expand_dims(vc11,off_C11) # Add a dimension for NC11
2177
2257
  else:
2178
- C11 = self.backend.bk_concat([C11,c11[:, :, None, :, :, :]],
2258
+ C11 = self.backend.bk_concat([C11,self.backend.bk_expand_dims(c11,off_C11)],
2179
2259
  axis=2) # Add a dimension for NC11
2180
2260
  if calc_var:
2181
- VC11 = self.backend.bk_concat([VC11,vc11[:, :, None, :, :, :]],
2261
+ VC11 = self.backend.bk_concat([VC11,self.backend.bk_expand_dims(vc11,off_C11)],
2182
2262
  axis=2) # Add a dimension for NC11
2183
2263
 
2184
2264
  ### C11_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
@@ -2201,18 +2281,18 @@ class funct(FOC.FoCUS):
2201
2281
  else:
2202
2282
  ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
2203
2283
  if norm is not None:
2204
- self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
2205
- P2_dic[j2][:, :, None, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2284
+ self.div_norm(c11,(self.backend.bk_expand_dims(self.backend.bk_expand_dims(P1_dic[j1],off_P0),off_P0) *
2285
+ self.backend.bk_expand_dims(self.backend.bk_expand_dims(P2_dic[j2],off_P0),-1)) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2206
2286
  ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
2207
2287
  if C11 is None:
2208
- C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
2288
+ C11 = self.backend.bk_expand_dims(c11,off_C11) # Add a dimension for NC11
2209
2289
  if calc_var:
2210
- VC11 = vc11[:, :, None, :, :, :] # Add a dimension for NC11
2290
+ VC11 = self.backend.bk_expand_dims(vc11,off_C11) # Add a dimension for NC11
2211
2291
  else:
2212
- C11 = self.backend.bk_concat([C11,c11[:, :, None, :, :, :]],
2292
+ C11 = self.backend.bk_concat([C11,self.backend.bk_expand_dims(c11,off_C11)],
2213
2293
  axis=2) # Add a dimension for NC11
2214
2294
  if calc_var:
2215
- VC11 = self.backend.bk_concat([VC11,vc11[:, :, None, :, :, :]],
2295
+ VC11 = self.backend.bk_concat([VC11,self.backend.bk_expand_dims(vc11,off_C11)],
2216
2296
  axis=2) # Add a dimension for NC11
2217
2297
 
2218
2298
  ###### Reshape for next iteration on j3
@@ -2298,7 +2378,10 @@ class funct(FOC.FoCUS):
2298
2378
  ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
2299
2379
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
2300
2380
  # cconv, sconv are [Nbatch, Npix_j3, Norient3]
2301
- c01 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(MconvPsi) # [Nbatch, Npix_j3, Norient3, Norient2]
2381
+ if self.use_1D:
2382
+ c01 = conv * self.backend.bk_conjugate(MconvPsi)
2383
+ else:
2384
+ c01 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(MconvPsi) # [Nbatch, Npix_j3, Norient3, Norient2]
2302
2385
 
2303
2386
  ### Apply the mask [Nmask, Npix_j3] and sum over pixels
2304
2387
  if return_data:
@@ -2327,7 +2410,10 @@ class funct(FOC.FoCUS):
2327
2410
 
2328
2411
  ### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
2329
2412
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
2330
- c11 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(self.backend.bk_expand_dims(M2, -1)) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
2413
+ if self.use_1D:
2414
+ c11 = M1 * self.backend.bk_conjugate(M2)
2415
+ else:
2416
+ c11 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(self.backend.bk_expand_dims(M2, -1)) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
2331
2417
 
2332
2418
  ### Apply the mask and sum over pixels
2333
2419
  if return_data:
@@ -1203,7 +1203,9 @@ class funct(FOC.FoCUS):
1203
1203
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
1204
1204
  p00 = conv1 * self.backend.bk_conjugate(conv2)
1205
1205
  # Apply the mask [Nmask, Npix_j3] and average over pixels
1206
+ p00 = self.backend.bk_real(p00)
1206
1207
  p00 = self.backend.bk_reduce_sum(p00*vmask, axis=1)
1208
+ print(p00.shape)
1207
1209
  tmp = self.backend.bk_L1(p00) # [Nbatch, Npix_j3, Norient3]
1208
1210
 
1209
1211
  ### Normalize P00_cross
@@ -1215,9 +1217,9 @@ class funct(FOC.FoCUS):
1215
1217
  p00=self.backend.bk_real(p00)
1216
1218
 
1217
1219
  if P00 is None:
1218
- P00 = p00[:,:,None,:] # Add a dimension for NP00
1220
+ P00 = p00[:,:,None] # Add a dimension for NP00
1219
1221
  else:
1220
- P00 = self.backend.bk_concat([P00, p00[:,:,None,:]], axis=2)
1222
+ P00 = self.backend.bk_concat([P00, p00[:,:,None]], axis=axis+2)
1221
1223
 
1222
1224
  #### S1_auto computation
1223
1225
  ### Image 1 : S1 = < M1 >_pix
@@ -0,0 +1,16 @@
1
+ import foscat.scat_cov as scat
2
+
3
+ class scat_cov1D:  # NOTE(review): appears intended as the 1D counterpart of the scat_cov statistics container (the result is tagged 'SCAT_COV1D') -- confirm
4
+ def __init__(self,p00,s0,s1,s2,s2l,j1,j2,cross=False,backend=None):  # NOTE(review): none of the parameters after self is read in the body below
5
+
6
+ the_scat=scat(P00, C01, C11, s1=S1, c10=C10,backend=self.backend)  # FIXME(review): P00/C01/C11/S1/C10 are not defined in this scope, self.backend is never assigned here, and `scat` is bound to the foscat.scat_cov module (not a class), so this call cannot succeed as written
7
+ the_scat.set_bk_type('SCAT_COV1D')
8
+ return the_scat  # FIXME(review): __init__ must return None; returning an object raises TypeError at instantiation
9
+
10
+ def fill(self,im,nullval=0):
11
+ return self.fill_1d(im,nullval=nullval)  # NOTE(review): fill_1d is not defined on this class and it has no base class; presumably expected from elsewhere -- verify
12
+
13
+ class funct(scat.funct):  # Specialization of scat.funct that always passes use_1D=True to the parent constructor
14
+ def __init__(self, *args, **kwargs):
15
+ # Force use_1D=True on the parent scat.funct constructor; all other arguments are forwarded unchanged
16
+ super(funct, self).__init__(use_1D=True, *args, **kwargs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: foscat
3
- Version: 3.0.47
3
+ Version: 3.1.0
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Home-page: https://github.com/jmdelouis/FOSCAT
6
6
  Author: Jean-Marc DELOUIS
@@ -18,6 +18,7 @@ src/foscat/scat.py
18
18
  src/foscat/scat1D.py
19
19
  src/foscat/scat2D.py
20
20
  src/foscat/scat_cov.py
21
+ src/foscat/scat_cov1D.old.py
21
22
  src/foscat/scat_cov1D.py
22
23
  src/foscat/scat_cov2D.py
23
24
  src/foscat/scat_cov_map.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes