foscat 3.0.46__py3-none-any.whl → 3.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -24,6 +24,7 @@ class FoCUS:
24
24
  TEMPLATE_PATH='data',
25
25
  BACKEND='tensorflow',
26
26
  use_2D=False,
27
+ use_1D=False,
27
28
  return_data=False,
28
29
  JmaxDelta=0,
29
30
  DODIV=False,
@@ -32,7 +33,7 @@ class FoCUS:
32
33
  mpi_size=1,
33
34
  mpi_rank=0):
34
35
 
35
- self.__version__ = '3.0.46'
36
+ self.__version__ = '3.1.0'
36
37
  # P00 coeff for normalization for scat_cov
37
38
  self.TMPFILE_VERSION=TMPFILE_VERSION
38
39
  self.P1_dic = None
@@ -86,6 +87,7 @@ class FoCUS:
86
87
 
87
88
  self.OSTEP=JmaxDelta
88
89
  self.use_2D=use_2D
90
+ self.use_1D=use_1D
89
91
 
90
92
  if isMPI:
91
93
  from mpi4py import MPI
@@ -214,12 +216,14 @@ class FoCUS:
214
216
 
215
217
 
216
218
  w_smooth=w_smooth.flatten()
217
-
219
+ if self.use_1D:
220
+ KERNELSZ=5
221
+
218
222
  self.KERNELSZ=KERNELSZ
219
223
 
220
224
  self.Idx_Neighbours={}
221
225
 
222
- if not self.use_2D:
226
+ if not self.use_2D and not self.use_1D:
223
227
  self.w_smooth = {}
224
228
  for i in range(nstep_max):
225
229
  lout=(2**i)
@@ -239,6 +243,21 @@ class FoCUS:
239
243
  self.ww_Real[lout]=wr
240
244
  self.ww_Imag[lout]=wi
241
245
  self.w_smooth[lout]=ws
246
+ elif self.use_1D==True:
247
+ self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
248
+ self.ww_RealT={}
249
+ self.ww_ImagT={}
250
+ self.ww_SmoothT={}
251
+ if KERNELSZ==5:
252
+ xx=np.arange(5)-2
253
+ w=np.exp(-0.25*(xx)**2)
254
+ c=np.cos((xx)*np.pi/2)
255
+ s=np.sin((xx)*np.pi/2)
256
+
257
+ self.ww_RealT[1]=self.backend.constant(np.array(w*c).reshape(xx.shape[0],1,1))
258
+ self.ww_ImagT[1]=self.backend.constant(np.array(w*s).reshape(xx.shape[0],1,1))
259
+ self.ww_SmoothT[1] = self.backend.constant(np.array(w).reshape(xx.shape[0],1,1))
260
+
242
261
  else:
243
262
  self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
244
263
  self.ww_RealT={}
@@ -573,7 +592,7 @@ class FoCUS:
573
592
  tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
574
593
  tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),0:2*(npiy//2),:],[ndata,npix//2,2,npiy//2,2,odata])
575
594
 
576
- res=self.backend.bk_reduce_mean(self.backend.bk_reduce_mean(tim,4),2)
595
+ res=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim,4),2)/4
577
596
 
578
597
  if axis==0:
579
598
  if len(ishape)==2:
@@ -587,6 +606,40 @@ class FoCUS:
587
606
  return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2]+ishape[axis+2:])
588
607
 
589
608
  return self.backend.bk_reshape(res,[npix//2,npiy//2])
609
+ elif self.use_1D:
610
+ ishape=list(im.shape)
611
+ if len(ishape)<axis+1:
612
+ if not self.silent:
613
+ print('Use of 1D scat with data that has less than 1D')
614
+ return None
615
+
616
+ npix=im.shape[axis]
617
+ odata=1
618
+ if len(ishape)>axis+1:
619
+ for k in range(axis+1,len(ishape)):
620
+ odata=odata*ishape[k]
621
+
622
+ ndata=1
623
+ for k in range(axis):
624
+ ndata=ndata*ishape[k]
625
+
626
+ tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
627
+ tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),:],[ndata,npix//2,2,odata])
628
+
629
+ res=self.backend.bk_reduce_mean(tim,2)
630
+
631
+ if axis==0:
632
+ if len(ishape)==1:
633
+ return self.backend.bk_reshape(res,[npix//2])
634
+ else:
635
+ return self.backend.bk_reshape(res,[npix//2]+ishape[axis+1:])
636
+ else:
637
+ if len(ishape)==axis+1:
638
+ return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2])
639
+ else:
640
+ return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2]+ishape[axis+1:])
641
+
642
+ return self.backend.bk_reshape(res,[npix//2])
590
643
 
591
644
  else:
592
645
  shape=list(im.shape)
@@ -653,6 +706,46 @@ class FoCUS:
653
706
  return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty]+ishape[axis+2:])
654
707
 
655
708
  return self.backend.bk_reshape(res,[nout,nouty])
709
+
710
+ elif self.use_1D:
711
+ ishape=list(im.shape)
712
+ if len(ishape)<axis+1:
713
+ if not self.silent:
714
+ print('Use of 1D scat with data that has less than 1D')
715
+ return None
716
+
717
+ if ishape[axis]==nout:
718
+ return im
719
+
720
+ npix=im.shape[axis]
721
+ odata=1
722
+ if len(ishape)>axis+1:
723
+ for k in range(axis+1,len(ishape)):
724
+ odata=odata*ishape[k]
725
+
726
+ ndata=1
727
+ for k in range(axis):
728
+ ndata=ndata*ishape[k]
729
+
730
+ tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
731
+
732
+ while tim.shape[1]!=nout:
733
+ res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
734
+ res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
735
+ tim = self.backend.bk_reshape(self.backend.bk_concat([res1,res2],-2),[ndata,tim.shape[1]*2,odata])
736
+
737
+ if axis==0:
738
+ if len(ishape)==1:
739
+ return self.backend.bk_reshape(tim,[nout])
740
+ else:
741
+ return self.backend.bk_reshape(tim,[nout]+ishape[axis+1:])
742
+ else:
743
+ if len(ishape)==axis+1:
744
+ return self.backend.bk_reshape(tim,ishape[0:axis]+[nout])
745
+ else:
746
+ return self.backend.bk_reshape(tim,ishape[0:axis]+[nout]+ishape[axis+1:])
747
+
748
+ return self.backend.bk_reshape(tim,[nout])
656
749
 
657
750
  else:
658
751
 
@@ -842,7 +935,7 @@ class FoCUS:
842
935
  ndata=ndata*ishape[k]
843
936
 
844
937
  tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
845
-
938
+
846
939
  res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
847
940
  res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
848
941
  res = self.backend.bk_concat([res1,res2],-2)
@@ -1278,66 +1371,117 @@ class FoCUS:
1278
1371
  sum_mask=self.backend.bk_reduce_sum(self.backend.bk_reshape(l_mask,[l_mask.shape[0],np.prod(np.array(l_mask.shape[1:]))]),1)
1279
1372
  if not self.use_2D:
1280
1373
  l_mask=12*nside*nside*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1281
- else:
1374
+ elif self.use_2D:
1282
1375
  l_mask=mask.shape[1]*mask.shape[2]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1283
-
1376
+ else:
1377
+ l_mask=mask.shape[1]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1378
+
1284
1379
  if self.use_2D:
1285
- if self.padding=='VALID' and shape[axis]!=l_mask.shape[1]:
1380
+ if self.padding=='VALID':
1286
1381
  l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1287
1382
  if shape[axis]!=l_mask.shape[1]:
1288
1383
  l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1289
- else:
1290
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1291
1384
 
1292
- # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,...,KERNELSZ//2:-self.KERNELSZ//2,KERNELSZ//2:-self.KERNELSZ//2,NORIENT[,NORIENT]]
1293
- if self.use_2D:
1294
1385
  ichannel=1
1295
1386
  for i in range(axis):
1296
1387
  ichannel*=shape[i]
1297
1388
  ochannel=1
1298
1389
  for i in range(axis+2,len(shape)):
1299
1390
  ochannel*=shape[i]
1300
- l_x=self.backend.bk_reshape(x,[ichannel,shape[axis],shape[axis+1],ochannel])
1301
- oshape=[k for k in shape]
1302
- oshape[axis]=oshape[axis]-self.KERNELSZ+1
1303
- oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
1304
- l_x=self.backend.bk_reshape(l_x[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1305
- else:
1306
- l_x=x
1391
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],shape[axis+1],ochannel])
1392
+
1393
+ if self.padding=='VALID':
1394
+ oshape=[k for k in shape]
1395
+ oshape[axis]=oshape[axis]-self.KERNELSZ+1
1396
+ oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
1397
+ l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1398
+
1399
+ elif self.use_1D:
1400
+ if self.padding=='VALID':
1401
+ l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1402
+ if shape[axis]!=l_mask.shape[1]:
1403
+ l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1404
+
1405
+ ichannel=1
1406
+ for i in range(axis):
1407
+ ichannel*=shape[i]
1408
+ ochannel=1
1409
+ for i in range(axis+1,len(shape)):
1410
+ ochannel*=shape[i]
1411
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1307
1412
 
1413
+ if self.padding=='VALID':
1414
+ oshape=[k for k in shape]
1415
+ oshape[axis]=oshape[axis]-self.KERNELSZ+1
1416
+ l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1417
+ else:
1418
+ ichannel=1
1419
+ for i in range(axis):
1420
+ ichannel*=shape[i]
1421
+ ochannel=1
1422
+ for i in range(axis+1,len(shape)):
1423
+ ochannel*=shape[i]
1424
+ l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1425
+
1308
1426
  # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,1,...,X[,Y],NORIENT[,NORIENT]]
1309
- l_x=self.backend.bk_expand_dims(l_x,1)
1310
-
1311
1427
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,X[,Y]]
1312
1428
  l_mask=self.backend.bk_expand_dims(l_mask,0)
1313
-
1314
- # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
1315
- for i in range(1,axis):
1316
- l_mask=self.backend.bk_expand_dims(l_mask,axis)
1429
+ # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,X[,Y],1]
1430
+ l_mask=self.backend.bk_expand_dims(l_mask,-1)
1317
1431
 
1318
1432
  if l_x.dtype==self.all_cbk_type:
1319
1433
  l_mask=self.backend.bk_complex(l_mask,self.backend.bk_cast(0.0*l_mask))
1320
1434
 
1321
1435
  if self.use_2D:
1436
+ mtmp=l_mask
1437
+ vtmp=l_x
1438
+
1439
+ v1=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp,axis=2),2)
1440
+ v2=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2),2)
1441
+ vh=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp,axis=2),2)
1322
1442
 
1323
- # mask=[1,Nmask,....,X,Y] => mask=[1,Nmask,....,X,Y,....]
1324
- for i in range(axis+2,len(x.shape)):
1325
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1326
-
1327
- shape1=list(l_mask.shape)
1328
- shape2=list(l_x.shape)
1443
+ res=v1/vh
1444
+
1445
+ oshape=[]
1446
+ if axis>0:
1447
+ oshape=oshape+list(x.shape[0:axis])
1448
+ oshape=oshape+[mask.shape[0]]
1449
+ if axis+1<len(x.shape):
1450
+ oshape=oshape+list(x.shape[axis+2:])
1451
+
1452
+ if calc_var:
1453
+ if self.backend.bk_is_complex(vtmp):
1454
+ res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1455
+ -self.backend.bk_real(res)*self.backend.bk_real(res)) + \
1456
+ (self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
1457
+ -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1458
+ else:
1459
+ res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1329
1460
 
1330
- oshape1=shape1[0:axis+1]+[shape1[axis+1]*shape1[axis+2]]+shape1[axis+3:]
1331
- oshape2=shape2[0:axis+1]+[shape2[axis+1]*shape2[axis+2]]+shape2[axis+3:]
1461
+ res=self.backend.bk_reshape(res,oshape)
1462
+ res2=self.backend.bk_reshape(res2,oshape)
1463
+ return res,res2
1464
+ else:
1465
+ res=self.backend.bk_reshape(res,oshape)
1466
+ return res
1332
1467
 
1333
- mtmp=self.backend.bk_reshape(l_mask,oshape1)
1334
- vtmp=self.backend.bk_reshape(l_x,oshape2)
1468
+ elif self.use_1D:
1469
+ mtmp=l_mask
1470
+ vtmp=l_x
1335
1471
 
1336
- v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=axis+1)
1337
- v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=axis+1)
1338
- vh=self.backend.bk_reduce_sum(mtmp,axis=axis+1)
1472
+ v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=2)
1473
+ v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2)
1474
+ vh=self.backend.bk_reduce_sum(mtmp,axis=2)
1339
1475
 
1340
1476
  res=v1/vh
1477
+
1478
+ oshape=[]
1479
+ if axis>0:
1480
+ oshape=oshape+list(x.shape[0:axis])
1481
+ oshape=oshape+[mask.shape[0]]
1482
+ if axis+1<len(x.shape):
1483
+ oshape=oshape+list(x.shape[axis+1:])
1484
+
1341
1485
  if calc_var:
1342
1486
  if self.backend.bk_is_complex(vtmp):
1343
1487
  res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
@@ -1346,19 +1490,29 @@ class FoCUS:
1346
1490
  -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1347
1491
  else:
1348
1492
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1493
+
1494
+
1495
+ res=self.backend.bk_reshape(res,oshape)
1496
+ res2=self.backend.bk_reshape(res2,oshape)
1349
1497
  return res,res2
1350
1498
  else:
1499
+ res=self.backend.bk_reshape(res,oshape)
1351
1500
  return res
1352
- else:
1353
- # mask=[1,Nmask,....,X] => mask=[1,Nmask,....,X,....]
1354
- for i in range(axis+1,len(x.shape)):
1355
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1356
-
1357
- v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=axis+1)
1358
- v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=axis+1)
1359
- vh=self.backend.bk_reduce_sum(l_mask,axis=axis+1)
1501
+
1502
+ else:
1503
+ v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=2)
1504
+ v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=2)
1505
+ vh=self.backend.bk_reduce_sum(l_mask,axis=2)
1360
1506
 
1361
1507
  res=v1/vh
1508
+
1509
+ oshape=[]
1510
+ if axis>0:
1511
+ oshape=oshape+list(x.shape[0:axis])
1512
+ oshape=oshape+[mask.shape[0]]
1513
+ if axis+1<len(x.shape):
1514
+ oshape=oshape+list(x.shape[axis+1:])
1515
+
1362
1516
  if calc_var:
1363
1517
  if self.backend.bk_is_complex(l_x):
1364
1518
  res2=self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
@@ -1367,8 +1521,12 @@ class FoCUS:
1367
1521
  -self.backend.bk_imag(res)*self.backend.bk_imag(res))/self.backend.bk_real(vh))
1368
1522
  else:
1369
1523
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1524
+
1525
+ res=self.backend.bk_reshape(res,oshape)
1526
+ res2=self.backend.bk_reshape(res2,oshape)
1370
1527
  return res,res2
1371
1528
  else:
1529
+ res=self.backend.bk_reshape(res,oshape)
1372
1530
  return res
1373
1531
 
1374
1532
  # -------------------------------------------------------
@@ -1485,7 +1643,6 @@ class FoCUS:
1485
1643
  image=self.backend.bk_cast(in_image)
1486
1644
 
1487
1645
  if self.use_2D:
1488
-
1489
1646
  ishape=list(in_image.shape)
1490
1647
  if len(ishape)<axis+2:
1491
1648
  if not self.silent:
@@ -1528,6 +1685,49 @@ class FoCUS:
1528
1685
  return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
1529
1686
 
1530
1687
  return self.backend.bk_reshape(res,[nout,nouty])
1688
+ elif self.use_1D==True:
1689
+ ishape=list(in_image.shape)
1690
+ if len(ishape)<axis+1:
1691
+ if not self.silent:
1692
+ print('Use of 1D scat with data that has less than 1D')
1693
+ return None
1694
+
1695
+ npix=ishape[axis]
1696
+ odata=1
1697
+ if len(ishape)>axis+1:
1698
+ for k in range(axis+1,len(ishape)):
1699
+ odata=odata*ishape[k]
1700
+
1701
+ ndata=1
1702
+ for k in range(axis):
1703
+ ndata=ndata*ishape[k]
1704
+
1705
+ tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
1706
+
1707
+ if self.backend.bk_is_complex(tim):
1708
+ rr1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1709
+ ii1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1710
+ rr2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1711
+ ii2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1712
+ res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
1713
+ else:
1714
+ rr=self.backend.conv1d(tim,self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1715
+ ii=self.backend.conv1d(tim,self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1716
+ res=self.backend.bk_complex(rr,ii)
1717
+
1718
+ if axis==0:
1719
+ if len(ishape)==1:
1720
+ return self.backend.bk_reshape(res,[res.shape[1]])
1721
+ else:
1722
+ return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+2:])
1723
+ else:
1724
+ if len(ishape)==axis+1:
1725
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
1726
+ else:
1727
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1728
+
1729
+ return self.backend.bk_reshape(res,[nout,nouty])
1730
+
1531
1731
 
1532
1732
  else:
1533
1733
  nside=int(np.sqrt(image.shape[axis]//12))
@@ -1649,7 +1849,46 @@ class FoCUS:
1649
1849
  return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]]+ishape[axis+2:])
1650
1850
 
1651
1851
  return self.backend.bk_reshape(res,[nout,nouty])
1852
+ elif self.use_1D:
1853
+
1854
+ ishape=list(in_image.shape)
1855
+ if len(ishape)<axis+1:
1856
+ if not self.silent:
1857
+ print('Use of 1D scat with data that has less than 1D')
1858
+ return None
1652
1859
 
1860
+ npix=ishape[axis]
1861
+ odata=1
1862
+ if len(ishape)>axis+1:
1863
+ for k in range(axis+1,len(ishape)):
1864
+ odata=odata*ishape[k]
1865
+
1866
+ ndata=1
1867
+ for k in range(axis):
1868
+ ndata=ndata*ishape[k]
1869
+
1870
+ tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
1871
+
1872
+ if self.backend.bk_is_complex(tim):
1873
+ rr=self.backend.conv1d(self.backend.bk_real(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1874
+ ii=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1875
+ res=self.backend.bk_complex(rr,ii)
1876
+ else:
1877
+ res=self.backend.conv1d(tim,self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1878
+
1879
+ if axis==0:
1880
+ if len(ishape)==1:
1881
+ return self.backend.bk_reshape(res,[res.shape[1]])
1882
+ else:
1883
+ return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+1:])
1884
+ else:
1885
+ if len(ishape)==axis+1:
1886
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
1887
+ else:
1888
+ return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1889
+
1890
+ return self.backend.bk_reshape(res,[nout,nouty])
1891
+
1653
1892
  else:
1654
1893
  nside=int(np.sqrt(image.shape[axis]//12))
1655
1894
 
foscat/backend.py CHANGED
@@ -271,9 +271,16 @@ class foscat_backend:
271
271
 
272
272
  def conv2d(self,x,w,strides=[1, 1, 1, 1],padding='SAME'):
273
273
  if self.BACKEND==self.TENSORFLOW:
274
- return self.backend.nn.conv2d(x,w,
275
- strides=strides,
276
- padding=padding)
274
+ kx=w.shape[0]
275
+ ky=w.shape[1]
276
+ paddings = self.backend.constant([[0,0],
277
+ [kx//2,kx//2],
278
+ [ky//2,ky//2],
279
+ [0,0]])
280
+ tmp=self.backend.pad(x, paddings, "SYMMETRIC")
281
+ return self.backend.nn.conv2d(tmp,w,
282
+ strides=strides,
283
+ padding="VALID")
277
284
  # to be written!!!
278
285
  if self.BACKEND==self.TORCH:
279
286
  return x
@@ -282,10 +289,33 @@ class foscat_backend:
282
289
  for k in range(w.shape[2]):
283
290
  for l in range(w.shape[3]):
284
291
  for j in range(res.shape[0]):
285
- tmp=self.scipy.signal.convolve2d(x[j,:,:,k],w[:,:,k,l], mode='same', boundary='fill', fillvalue=0.0)
292
+ tmp=self.scipy.signal.convolve2d(x[j,:,:,k],w[:,:,k,l], mode='same', boundary='symm')
286
293
  res[j,:,:,l]+=tmp
287
294
  del tmp
288
295
  return res
296
+
297
def conv1d(self,x,w,strides=[1, 1, 1],padding='SAME'):
    """Convolve a batch of 1D signals with a bank of 1D kernels.

    The signal axis is extended with a symmetric (mirror) boundary so the
    output keeps the length of the input for odd kernel sizes.

    Parameters
    ----------
    x : rank-3 array/tensor indexed as [batch, position, input channel]
        (the TENSORFLOW branch pads exactly three axes, the NUMPY branch
        indexes ``x[j,:,k]``).
    w : rank-3 kernel indexed as [kernel position, input channel,
        output channel].
    strides : forwarded as ``stride`` to the TensorFlow ``nn.conv1d`` call;
        not used by the NUMPY branch.
    padding : accepted for signature compatibility; none of the branches
        below reads it (the boundary handling is always symmetric).

    Returns
    -------
    The convolved data for the TENSORFLOW and NUMPY backends, ``x``
    unchanged for the TORCH backend (not implemented yet), and ``None``
    when ``self.BACKEND`` matches none of them.
    """
    if self.BACKEND==self.TENSORFLOW:
        kx=w.shape[0]
        # Mirror-pad only the position axis, then run a VALID convolution
        # so no zero padding leaks into the borders.
        paddings = self.backend.constant([[0,0],
                                          [kx//2,kx//2],
                                          [0,0]])
        tmp=self.backend.pad(x, paddings, "SYMMETRIC")

        return self.backend.nn.conv1d(tmp,w,
                                      stride=strides,
                                      padding="VALID")
    # to be written!!!
    if self.BACKEND==self.TORCH:
        return x
    if self.BACKEND==self.NUMPY:
        kx=w.shape[0]
        # Left/right pad widths chosen so that a 'valid' convolution of the
        # padded signal is centred like a 'same' convolution; both widths
        # equal kx//2 for odd kernels.
        padded=np.pad(x,((0,0),(kx//2,(kx-1)//2),(0,0)),mode='symmetric')
        res=np.zeros([x.shape[0],x.shape[1],w.shape[2]],dtype=x.dtype)
        # k: input channel, l: output channel, j: batch item.
        for k in range(w.shape[1]):
            for l in range(w.shape[2]):
                for j in range(res.shape[0]):
                    # np.convolve flips the kernel (true convolution), like
                    # the scipy convolve call this branch was modelled on.
                    # NOTE(review): TensorFlow's conv1d is a cross-correlation,
                    # so results differ in sign convention for
                    # non-symmetric kernels - confirm which one is wanted.
                    res[j,:,l]+=np.convolve(padded[j,:,k],w[:,k,l],mode='valid')
        return res
289
319
 
290
320
  def bk_threshold(self,x,threshold,greater=True):
291
321