foscat 3.0.47__py3-none-any.whl → 3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -24,6 +24,7 @@ class FoCUS:
24
24
  TEMPLATE_PATH='data',
25
25
  BACKEND='tensorflow',
26
26
  use_2D=False,
27
+ use_1D=False,
27
28
  return_data=False,
28
29
  JmaxDelta=0,
29
30
  DODIV=False,
@@ -32,7 +33,7 @@ class FoCUS:
32
33
  mpi_size=1,
33
34
  mpi_rank=0):
34
35
 
35
- self.__version__ = '3.0.47'
36
+ self.__version__ = '3.1.1'
36
37
  # P00 coeff for normalization for scat_cov
37
38
  self.TMPFILE_VERSION=TMPFILE_VERSION
38
39
  self.P1_dic = None
@@ -86,6 +87,7 @@ class FoCUS:
86
87
 
87
88
  self.OSTEP=JmaxDelta
88
89
  self.use_2D=use_2D
90
+ self.use_1D=use_1D
89
91
 
90
92
  if isMPI:
91
93
  from mpi4py import MPI
@@ -214,12 +216,14 @@ class FoCUS:
214
216
 
215
217
 
216
218
  w_smooth=w_smooth.flatten()
217
-
219
+ if self.use_1D:
220
+ KERNELSZ=5
221
+
218
222
  self.KERNELSZ=KERNELSZ
219
223
 
220
224
  self.Idx_Neighbours={}
221
225
 
222
- if not self.use_2D:
226
+ if not self.use_2D and not self.use_1D:
223
227
  self.w_smooth = {}
224
228
  for i in range(nstep_max):
225
229
  lout=(2**i)
@@ -239,6 +243,27 @@ class FoCUS:
239
243
  self.ww_Real[lout]=wr
240
244
  self.ww_Imag[lout]=wi
241
245
  self.w_smooth[lout]=ws
246
+ elif self.use_1D==True:
247
+ self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
248
+ self.ww_RealT={}
249
+ self.ww_ImagT={}
250
+ self.ww_SmoothT={}
251
+ if KERNELSZ==5:
252
+ xx=np.arange(5)-2
253
+ w=np.exp(-0.25*(xx)**2)
254
+ c=w*np.cos((xx)*np.pi/2)
255
+ s=w*np.sin((xx)*np.pi/2)
256
+
257
+ w=w/np.sum(w)
258
+ c=c-np.mean(c)
259
+ s=s-np.mean(s)
260
+ r=np.sum(np.sqrt(c*c+s*s))
261
+ c=c/r
262
+ s=s/r
263
+ self.ww_RealT[1]=self.backend.constant(np.array(c).reshape(xx.shape[0],1,1))
264
+ self.ww_ImagT[1]=self.backend.constant(np.array(s).reshape(xx.shape[0],1,1))
265
+ self.ww_SmoothT[1] = self.backend.constant(np.array(w).reshape(xx.shape[0],1,1))
266
+
242
267
  else:
243
268
  self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
244
269
  self.ww_RealT={}
@@ -573,7 +598,7 @@ class FoCUS:
573
598
  tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
574
599
  tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),0:2*(npiy//2),:],[ndata,npix//2,2,npiy//2,2,odata])
575
600
 
576
- res=self.backend.bk_reduce_mean(self.backend.bk_reduce_mean(tim,4),2)
601
+ res=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim,4),2)/4
577
602
 
578
603
  if axis==0:
579
604
  if len(ishape)==2:
@@ -587,6 +612,40 @@ class FoCUS:
587
612
  return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2]+ishape[axis+2:])
588
613
 
589
614
  return self.backend.bk_reshape(res,[npix//2,npiy//2])
615
+ elif self.use_1D:
616
+ ishape=list(im.shape)
617
+ if len(ishape)<axis+1:
618
+ if not self.silent:
619
+ print('Use of 1D scat with data that has less than 1D')
620
+ return None
621
+
622
+ npix=im.shape[axis]
623
+ odata=1
624
+ if len(ishape)>axis+1:
625
+ for k in range(axis+1,len(ishape)):
626
+ odata=odata*ishape[k]
627
+
628
+ ndata=1
629
+ for k in range(axis):
630
+ ndata=ndata*ishape[k]
631
+
632
+ tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
633
+ tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),:],[ndata,npix//2,2,odata])
634
+
635
+ res=self.backend.bk_reduce_mean(tim,2)
636
+
637
+ if axis==0:
638
+ if len(ishape)==1:
639
+ return self.backend.bk_reshape(res,[npix//2])
640
+ else:
641
+ return self.backend.bk_reshape(res,[npix//2]+ishape[axis+1:])
642
+ else:
643
+ if len(ishape)==axis+1:
644
+ return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2])
645
+ else:
646
+ return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2]+ishape[axis+1:])
647
+
648
+ return self.backend.bk_reshape(res,[npix//2])
590
649
 
591
650
  else:
592
651
  shape=list(im.shape)
@@ -653,6 +712,46 @@ class FoCUS:
653
712
  return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty]+ishape[axis+2:])
654
713
 
655
714
  return self.backend.bk_reshape(res,[nout,nouty])
715
+
716
+ elif self.use_1D:
717
+ ishape=list(im.shape)
718
+ if len(ishape)<axis+1:
719
+ if not self.silent:
720
+ print('Use of 1D scat with data that has less than 1D')
721
+ return None
722
+
723
+ if ishape[axis]==nout:
724
+ return im
725
+
726
+ npix=im.shape[axis]
727
+ odata=1
728
+ if len(ishape)>axis+1:
729
+ for k in range(axis+1,len(ishape)):
730
+ odata=odata*ishape[k]
731
+
732
+ ndata=1
733
+ for k in range(axis):
734
+ ndata=ndata*ishape[k]
735
+
736
+ tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
737
+
738
+ while tim.shape[1]!=nout:
739
+ res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
740
+ res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
741
+ tim = self.backend.bk_reshape(self.backend.bk_concat([res1,res2],-2),[ndata,tim.shape[1]*2,odata])
742
+
743
+ if axis==0:
744
+ if len(ishape)==1:
745
+ return self.backend.bk_reshape(tim,[nout])
746
+ else:
747
+ return self.backend.bk_reshape(tim,[nout]+ishape[axis+1:])
748
+ else:
749
+ if len(ishape)==axis+1:
750
+ return self.backend.bk_reshape(tim,ishape[0:axis]+[nout])
751
+ else:
752
+ return self.backend.bk_reshape(tim,ishape[0:axis]+[nout]+ishape[axis+1:])
753
+
754
+ return self.backend.bk_reshape(tim,[nout])
656
755
 
657
756
  else:
658
757
 
@@ -842,7 +941,7 @@ class FoCUS:
842
941
  ndata=ndata*ishape[k]
843
942
 
844
943
  tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
845
-
944
+
846
945
  res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
847
946
  res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
848
947
  res = self.backend.bk_concat([res1,res2],-2)
@@ -1278,66 +1377,117 @@ class FoCUS:
1278
1377
  sum_mask=self.backend.bk_reduce_sum(self.backend.bk_reshape(l_mask,[l_mask.shape[0],np.prod(np.array(l_mask.shape[1:]))]),1)
1279
1378
  if not self.use_2D:
1280
1379
  l_mask=12*nside*nside*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1281
- else:
1380
+ elif self.use_2D:
1282
1381
  l_mask=mask.shape[1]*mask.shape[2]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1283
-
1382
+ else:
1383
+ l_mask=mask.shape[1]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1384
+
1284
1385
  if self.use_2D:
1285
- if self.padding=='VALID' and shape[axis]!=l_mask.shape[1]:
1386
+ if self.padding=='VALID':
1286
1387
  l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1287
1388
  if shape[axis]!=l_mask.shape[1]:
1288
1389
  l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1289
- else:
1290
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1291
1390
 
1292
- # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,...,KERNELSZ//2:-self.KERNELSZ//2,KERNELSZ//2:-self.KERNELSZ//2,NORIENT[,NORIENT]]
1293
- if self.use_2D:
1294
1391
  ichannel=1
1295
1392
  for i in range(axis):
1296
1393
  ichannel*=shape[i]
1297
1394
  ochannel=1
1298
1395
  for i in range(axis+2,len(shape)):
1299
1396
  ochannel*=shape[i]
1300
- l_x=self.backend.bk_reshape(x,[ichannel,shape[axis],shape[axis+1],ochannel])
1301
- oshape=[k for k in shape]
1302
- oshape[axis]=oshape[axis]-self.KERNELSZ+1
1303
- oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
1304
- l_x=self.backend.bk_reshape(l_x[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1305
- else:
1306
- l_x=x
1397
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],shape[axis+1],ochannel])
1398
+
1399
+ if self.padding=='VALID':
1400
+ oshape=[k for k in shape]
1401
+ oshape[axis]=oshape[axis]-self.KERNELSZ+1
1402
+ oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
1403
+ l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1404
+
1405
+ elif self.use_1D:
1406
+ if self.padding=='VALID':
1407
+ l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1408
+ if shape[axis]!=l_mask.shape[1]:
1409
+ l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1410
+
1411
+ ichannel=1
1412
+ for i in range(axis):
1413
+ ichannel*=shape[i]
1414
+ ochannel=1
1415
+ for i in range(axis+1,len(shape)):
1416
+ ochannel*=shape[i]
1417
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1307
1418
 
1419
+ if self.padding=='VALID':
1420
+ oshape=[k for k in shape]
1421
+ oshape[axis]=oshape[axis]-self.KERNELSZ+1
1422
+ l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1423
+ else:
1424
+ ichannel=1
1425
+ for i in range(axis):
1426
+ ichannel*=shape[i]
1427
+ ochannel=1
1428
+ for i in range(axis+1,len(shape)):
1429
+ ochannel*=shape[i]
1430
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1431
+
1308
1432
  # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,1,...,X[,Y],NORIENT[,NORIENT]]
1309
- l_x=self.backend.bk_expand_dims(l_x,1)
1310
-
1311
1433
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,X[,Y]]
1312
1434
  l_mask=self.backend.bk_expand_dims(l_mask,0)
1313
-
1314
- # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
1315
- for i in range(1,axis):
1316
- l_mask=self.backend.bk_expand_dims(l_mask,axis)
1435
+ # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,X[,Y],1]
1436
+ l_mask=self.backend.bk_expand_dims(l_mask,-1)
1317
1437
 
1318
1438
  if l_x.dtype==self.all_cbk_type:
1319
1439
  l_mask=self.backend.bk_complex(l_mask,self.backend.bk_cast(0.0*l_mask))
1320
1440
 
1321
1441
  if self.use_2D:
1442
+ mtmp=l_mask
1443
+ vtmp=l_x
1444
+
1445
+ v1=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp,axis=2),2)
1446
+ v2=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2),2)
1447
+ vh=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp,axis=2),2)
1322
1448
 
1323
- # mask=[1,Nmask,....,X,Y] => mask=[1,Nmask,....,X,Y,....]
1324
- for i in range(axis+2,len(x.shape)):
1325
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1326
-
1327
- shape1=list(l_mask.shape)
1328
- shape2=list(l_x.shape)
1449
+ res=v1/vh
1450
+
1451
+ oshape=[]
1452
+ if axis>0:
1453
+ oshape=oshape+list(x.shape[0:axis])
1454
+ oshape=oshape+[mask.shape[0]]
1455
+ if axis+1<len(x.shape):
1456
+ oshape=oshape+list(x.shape[axis+2:])
1457
+
1458
+ if calc_var:
1459
+ if self.backend.bk_is_complex(vtmp):
1460
+ res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1461
+ -self.backend.bk_real(res)*self.backend.bk_real(res)) + \
1462
+ (self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
1463
+ -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1464
+ else:
1465
+ res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1329
1466
 
1330
- oshape1=shape1[0:axis+1]+[shape1[axis+1]*shape1[axis+2]]+shape1[axis+3:]
1331
- oshape2=shape2[0:axis+1]+[shape2[axis+1]*shape2[axis+2]]+shape2[axis+3:]
1467
+ res=self.backend.bk_reshape(res,oshape)
1468
+ res2=self.backend.bk_reshape(res2,oshape)
1469
+ return res,res2
1470
+ else:
1471
+ res=self.backend.bk_reshape(res,oshape)
1472
+ return res
1332
1473
 
1333
- mtmp=self.backend.bk_reshape(l_mask,oshape1)
1334
- vtmp=self.backend.bk_reshape(l_x,oshape2)
1474
+ elif self.use_1D:
1475
+ mtmp=l_mask
1476
+ vtmp=l_x
1335
1477
 
1336
- v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=axis+1)
1337
- v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=axis+1)
1338
- vh=self.backend.bk_reduce_sum(mtmp,axis=axis+1)
1478
+ v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=2)
1479
+ v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2)
1480
+ vh=self.backend.bk_reduce_sum(mtmp,axis=2)
1339
1481
 
1340
1482
  res=v1/vh
1483
+
1484
+ oshape=[]
1485
+ if axis>0:
1486
+ oshape=oshape+list(x.shape[0:axis])
1487
+ oshape=oshape+[mask.shape[0]]
1488
+ if axis+1<len(x.shape):
1489
+ oshape=oshape+list(x.shape[axis+1:])
1490
+
1341
1491
  if calc_var:
1342
1492
  if self.backend.bk_is_complex(vtmp):
1343
1493
  res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
@@ -1346,19 +1496,29 @@ class FoCUS:
1346
1496
  -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1347
1497
  else:
1348
1498
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1499
+
1500
+
1501
+ res=self.backend.bk_reshape(res,oshape)
1502
+ res2=self.backend.bk_reshape(res2,oshape)
1349
1503
  return res,res2
1350
1504
  else:
1505
+ res=self.backend.bk_reshape(res,oshape)
1351
1506
  return res
1352
- else:
1353
- # mask=[1,Nmask,....,X] => mask=[1,Nmask,....,X,....]
1354
- for i in range(axis+1,len(x.shape)):
1355
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1356
-
1357
- v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=axis+1)
1358
- v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=axis+1)
1359
- vh=self.backend.bk_reduce_sum(l_mask,axis=axis+1)
1507
+
1508
+ else:
1509
+ v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=2)
1510
+ v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=2)
1511
+ vh=self.backend.bk_reduce_sum(l_mask,axis=2)
1360
1512
 
1361
1513
  res=v1/vh
1514
+
1515
+ oshape=[]
1516
+ if axis>0:
1517
+ oshape=oshape+list(x.shape[0:axis])
1518
+ oshape=oshape+[mask.shape[0]]
1519
+ if axis+1<len(x.shape):
1520
+ oshape=oshape+list(x.shape[axis+1:])
1521
+
1362
1522
  if calc_var:
1363
1523
  if self.backend.bk_is_complex(l_x):
1364
1524
  res2=self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
@@ -1367,8 +1527,12 @@ class FoCUS:
1367
1527
  -self.backend.bk_imag(res)*self.backend.bk_imag(res))/self.backend.bk_real(vh))
1368
1528
  else:
1369
1529
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1530
+
1531
+ res=self.backend.bk_reshape(res,oshape)
1532
+ res2=self.backend.bk_reshape(res2,oshape)
1370
1533
  return res,res2
1371
1534
  else:
1535
+ res=self.backend.bk_reshape(res,oshape)
1372
1536
  return res
1373
1537
 
1374
1538
  # ---------------------------------------------−---------
@@ -1485,7 +1649,6 @@ class FoCUS:
1485
1649
  image=self.backend.bk_cast(in_image)
1486
1650
 
1487
1651
  if self.use_2D:
1488
-
1489
1652
  ishape=list(in_image.shape)
1490
1653
  if len(ishape)<axis+2:
1491
1654
  if not self.silent:
@@ -1528,6 +1691,49 @@ class FoCUS:
1528
1691
  return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
1529
1692
 
1530
1693
  return self.backend.bk_reshape(res,[nout,nouty])
1694
+ elif self.use_1D==True:
1695
+ ishape=list(in_image.shape)
1696
+ if len(ishape)<axis+1:
1697
+ if not self.silent:
1698
+ print('Use of 1D scat with data that has less than 1D')
1699
+ return None
1700
+
1701
+ npix=ishape[axis]
1702
+ odata=1
1703
+ if len(ishape)>axis+1:
1704
+ for k in range(axis+1,len(ishape)):
1705
+ odata=odata*ishape[k]
1706
+
1707
+ ndata=1
1708
+ for k in range(axis):
1709
+ ndata=ndata*ishape[k]
1710
+
1711
+ tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
1712
+
1713
+ if self.backend.bk_is_complex(tim):
1714
+ rr1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1715
+ ii1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1716
+ rr2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1717
+ ii2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1718
+ res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
1719
+ else:
1720
+ rr=self.backend.conv1d(tim,self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1721
+ ii=self.backend.conv1d(tim,self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1722
+ res=self.backend.bk_complex(rr,ii)
1723
+
1724
+ if axis==0:
1725
+ if len(ishape)==1:
1726
+ return self.backend.bk_reshape(res,[res.shape[1]])
1727
+ else:
1728
+ return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+2:])
1729
+ else:
1730
+ if len(ishape)==axis+1:
1731
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
1732
+ else:
1733
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1734
+
1735
+ return self.backend.bk_reshape(res,[nout,nouty])
1736
+
1531
1737
 
1532
1738
  else:
1533
1739
  nside=int(np.sqrt(image.shape[axis]//12))
@@ -1649,7 +1855,46 @@ class FoCUS:
1649
1855
  return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]]+ishape[axis+2:])
1650
1856
 
1651
1857
  return self.backend.bk_reshape(res,[nout,nouty])
1858
+ elif self.use_1D:
1859
+
1860
+ ishape=list(in_image.shape)
1861
+ if len(ishape)<axis+1:
1862
+ if not self.silent:
1863
+ print('Use of 1D scat with data that has less than 1D')
1864
+ return None
1652
1865
 
1866
+ npix=ishape[axis]
1867
+ odata=1
1868
+ if len(ishape)>axis+1:
1869
+ for k in range(axis+1,len(ishape)):
1870
+ odata=odata*ishape[k]
1871
+
1872
+ ndata=1
1873
+ for k in range(axis):
1874
+ ndata=ndata*ishape[k]
1875
+
1876
+ tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
1877
+
1878
+ if self.backend.bk_is_complex(tim):
1879
+ rr=self.backend.conv1d(self.backend.bk_real(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1880
+ ii=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1881
+ res=self.backend.bk_complex(rr,ii)
1882
+ else:
1883
+ res=self.backend.conv1d(tim,self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1884
+
1885
+ if axis==0:
1886
+ if len(ishape)==1:
1887
+ return self.backend.bk_reshape(res,[res.shape[1]])
1888
+ else:
1889
+ return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+1:])
1890
+ else:
1891
+ if len(ishape)==axis+1:
1892
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
1893
+ else:
1894
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1895
+
1896
+ return self.backend.bk_reshape(res,[nout,nouty])
1897
+
1653
1898
  else:
1654
1899
  nside=int(np.sqrt(image.shape[axis]//12))
1655
1900
 
foscat/backend.py CHANGED
@@ -289,10 +289,33 @@ class foscat_backend:
289
289
  for k in range(w.shape[2]):
290
290
  for l in range(w.shape[3]):
291
291
  for j in range(res.shape[0]):
292
- tmp=self.scipy.signal.convolve2d(x[j,:,:,k],w[:,:,k,l], mode='same', boundary='fill', fillvalue=0.0)
292
+ tmp=self.scipy.signal.convolve2d(x[j,:,:,k],w[:,:,k,l], mode='same', boundary='symm')
293
293
  res[j,:,:,l]+=tmp
294
294
  del tmp
295
295
  return res
296
+
297
+ def conv1d(self,x,w,strides=[1, 1, 1],padding='SAME'):  # 1-D convolution of x with kernel w along axis 1. NOTE(review): `padding` is never read in any branch below; `strides` is only forwarded in the TENSORFLOW branch.
298
+ if self.BACKEND==self.TENSORFLOW:
299
+ kx=w.shape[0]  # kernel length, taken from the first dimension of w
300
+ paddings = self.backend.constant([[0,0],  # pad only axis 1 of x, by kx//2 on each side
301
+ [kx//2,kx//2],
302
+ [0,0]])
303
+ tmp=self.backend.pad(x, paddings, "SYMMETRIC")  # mirror-extend the borders so the VALID convolution below returns the input length on axis 1 (for odd kx; one sample shorter for even kx)
304
+
305
+ return self.backend.nn.conv1d(tmp,w,  # presumably tf.nn.conv1d: x as [batch, width, in_channels], w as [width, in_channels, out_channels] — confirm self.backend is the tensorflow module here
306
+ stride=strides,
307
+ padding="VALID")  # hard-coded: the caller's `padding` argument has no effect on the output length
308
+ # to be written!!!
309
+ if self.BACKEND==self.TORCH:
310
+ return x  # NOTE(review): placeholder — returns the input unchanged, no convolution is applied
311
+ if self.BACKEND==self.NUMPY:
312
+ res=np.zeros([x.shape[0],x.shape[1],w.shape[2]],dtype=x.dtype)  # 3-D output buffer: [x.shape[0], x.shape[1], w.shape[2]]
313
+ for k in range(w.shape[2]):  # NOTE(review): k ranges over w.shape[2] but is used below to index axis 1 of w and axis 2 of x — verify the intended channel layout
314
+ for j in range(res.shape[0]):
315
+ tmp=self.scipy.signal.convolve1d(x[j,:,k],w[:,k,l], mode='same', boundary='symm')  # NOTE(review): fails as written — `l` is never defined in conv1d, and scipy.signal appears to have no convolve1d (scipy.ndimage.convolve1d exists, with a different signature) — verify
316
+ res[j,:,:,l]+=tmp  # NOTE(review): four indices on the 3-D array allocated above (IndexError if ever reached)
317
+ del tmp
318
+ return res  # NOTE(review): if BACKEND matches none of the three branches, the method falls through and returns None
296
319
 
297
320
  def bk_threshold(self,x,threshold,greater=True):
298
321