foscat 3.0.35__tar.gz → 3.0.40__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {foscat-3.0.35 → foscat-3.0.40}/PKG-INFO +1 -1
- {foscat-3.0.35 → foscat-3.0.40}/setup.py +1 -1
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/FoCUS.py +122 -12
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/GCNN.py +65 -13
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/backend.py +31 -8
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/loss_backend_torch.py +30 -10
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat.py +25 -1
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov.py +50 -15
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/PKG-INFO +1 -1
- {foscat-3.0.35 → foscat-3.0.40}/README.md +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/setup.cfg +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/CNN.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/CircSpline.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/GetGPUinfo.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/Softmax.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/Spline1D.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/Synthesis.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/__init__.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/backend_tens.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/loss_backend_tens.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat1D.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat2D.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov1D.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov2D.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov_map.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov_map2D.py +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/SOURCES.txt +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/dependency_links.txt +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/requires.txt +0 -0
- {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/top_level.txt +0 -0
|
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
|
|
3
3
|
|
|
4
4
|
setup(
|
|
5
5
|
name='foscat',
|
|
6
|
-
version='3.0.
|
|
6
|
+
version='3.0.40',
|
|
7
7
|
description='Generate synthetic Healpix or 2D data using Cross Scattering Transform' ,
|
|
8
8
|
long_description='Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising. \n A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO. \n Complete doc can be found at https://foscat-documentation.readthedocs.io/en/latest/index.html. \n\n List of developers : J.-M. Delouis, T. Foulquier, L. Mousset, T. Odaka, F. Paul, E. Allys ' ,
|
|
9
9
|
license='MIT',
|
|
@@ -5,7 +5,7 @@ import foscat.backend as bk
|
|
|
5
5
|
from scipy.interpolate import griddata
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
TMPFILE_VERSION='
|
|
8
|
+
TMPFILE_VERSION='V4_0'
|
|
9
9
|
|
|
10
10
|
class FoCUS:
|
|
11
11
|
def __init__(self,
|
|
@@ -32,7 +32,7 @@ class FoCUS:
|
|
|
32
32
|
mpi_size=1,
|
|
33
33
|
mpi_rank=0):
|
|
34
34
|
|
|
35
|
-
self.__version__ = '3.0.
|
|
35
|
+
self.__version__ = '3.0.40'
|
|
36
36
|
# P00 coeff for normalization for scat_cov
|
|
37
37
|
self.TMPFILE_VERSION=TMPFILE_VERSION
|
|
38
38
|
self.P1_dic = None
|
|
@@ -987,6 +987,98 @@ class FoCUS:
|
|
|
987
987
|
tmp=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel**2,self.NORIENT,nside))
|
|
988
988
|
except:
|
|
989
989
|
if self.use_2D==False:
|
|
990
|
+
|
|
991
|
+
if l_kernel==5:
|
|
992
|
+
pw=0.5
|
|
993
|
+
pw2=0.5
|
|
994
|
+
threshold=2E-4
|
|
995
|
+
|
|
996
|
+
elif l_kernel==3:
|
|
997
|
+
pw=1.0/np.sqrt(2)
|
|
998
|
+
pw2=1.0
|
|
999
|
+
threshold=1E-3
|
|
1000
|
+
|
|
1001
|
+
elif l_kernel==7:
|
|
1002
|
+
pw=0.5
|
|
1003
|
+
pw2=0.25
|
|
1004
|
+
threshold=4E-5
|
|
1005
|
+
|
|
1006
|
+
th,ph=hp.pix2ang(nside,np.arange(12*nside**2),nest=True)
|
|
1007
|
+
x,y,z=hp.pix2vec(nside,np.arange(12*nside**2),nest=True)
|
|
1008
|
+
|
|
1009
|
+
t,p=hp.pix2ang(nside,np.arange(12*nside*nside),nest=True)
|
|
1010
|
+
phi=[p[k]/np.pi*180 for k in range(12*nside*nside)]
|
|
1011
|
+
thi=[t[k]/np.pi*180 for k in range(12*nside*nside)]
|
|
1012
|
+
|
|
1013
|
+
|
|
1014
|
+
indice2=np.zeros([12*nside*nside*64,2],dtype='int')
|
|
1015
|
+
indice=np.zeros([12*nside*nside*64*self.NORIENT,2],dtype='int')
|
|
1016
|
+
wav=np.zeros([12*nside*nside*64*self.NORIENT],dtype='complex')
|
|
1017
|
+
wwav=np.zeros([12*nside*nside*64*self.NORIENT],dtype='float')
|
|
1018
|
+
|
|
1019
|
+
iv=0
|
|
1020
|
+
iv2=0
|
|
1021
|
+
for iii in range(12*nside*nside):
|
|
1022
|
+
|
|
1023
|
+
if iii%(nside*nside)==nside*nside-1:
|
|
1024
|
+
if not self.silent:
|
|
1025
|
+
print('Pre-compute nside=%6d %.2f%%'%(nside,100*iii/(12*nside*nside)))
|
|
1026
|
+
|
|
1027
|
+
hidx=hp.query_disc(nside, [x[iii],y[iii],z[iii]], 2*np.pi/nside,nest=True)
|
|
1028
|
+
|
|
1029
|
+
R=hp.Rotator(rot=[phi[iii],-thi[iii]],eulertype='ZYZ')
|
|
1030
|
+
|
|
1031
|
+
t2,p2=R(th[hidx],ph[hidx])
|
|
1032
|
+
|
|
1033
|
+
vec2=hp.ang2vec(t2,p2)
|
|
1034
|
+
|
|
1035
|
+
x2=vec2[:,0]
|
|
1036
|
+
y2=vec2[:,1]
|
|
1037
|
+
z2=vec2[:,2]
|
|
1038
|
+
|
|
1039
|
+
ww=np.exp(-pw2*((nside)**2)*((x2)**2+(y2)**2+(z2-1.0)**2))
|
|
1040
|
+
idx=np.where((ww**2)>threshold)[0]
|
|
1041
|
+
nval2=len(idx)
|
|
1042
|
+
indice2[iv2:iv2+nval2,0]=iii
|
|
1043
|
+
indice2[iv2:iv2+nval2,1]=hidx[idx]
|
|
1044
|
+
wwav[iv2:iv2+nval2]=ww[idx]/np.sum(ww[idx])
|
|
1045
|
+
iv2+=nval2
|
|
1046
|
+
|
|
1047
|
+
for l in range(self.NORIENT):
|
|
1048
|
+
|
|
1049
|
+
angle=l/4.*np.pi-phi[iii]/180.*np.pi*(z[hidx]>0)-(180.0-phi[iii])/180.*np.pi*(z[hidx]<0)
|
|
1050
|
+
|
|
1051
|
+
#posi=2*(0.5-(z[hidx]<0))
|
|
1052
|
+
|
|
1053
|
+
axes=y2*np.cos(angle)-x2*np.sin(angle)
|
|
1054
|
+
wresr=ww*np.cos(pw*axes*(nside)*np.pi)
|
|
1055
|
+
wresi=ww*np.sin(pw*axes*(nside)*np.pi)
|
|
1056
|
+
|
|
1057
|
+
vnorm=(wresr*wresr+wresi*wresi)
|
|
1058
|
+
idx=np.where(vnorm>threshold)[0]
|
|
1059
|
+
|
|
1060
|
+
nval=len(idx)
|
|
1061
|
+
indice[iv:iv+nval,0]=iii*4+l
|
|
1062
|
+
indice[iv:iv+nval,1]=hidx[idx]
|
|
1063
|
+
#print([hidx[k] for k in idx])
|
|
1064
|
+
#print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
|
|
1065
|
+
normr=np.mean(wresr[idx])
|
|
1066
|
+
normi=np.mean(wresi[idx])
|
|
1067
|
+
|
|
1068
|
+
val=wresr[idx]-normr+np.complex(0,1)*(wresi[idx]-normi)
|
|
1069
|
+
val=val/abs(val).sum()
|
|
1070
|
+
|
|
1071
|
+
wav[iv:iv+nval]=val
|
|
1072
|
+
iv+=nval
|
|
1073
|
+
|
|
1074
|
+
indice=indice[:iv,:]
|
|
1075
|
+
wav=wav[:iv]
|
|
1076
|
+
indice2=indice2[:iv2,:]
|
|
1077
|
+
wwav=wwav[:iv2]
|
|
1078
|
+
if not self.silent:
|
|
1079
|
+
print('Kernel Size ',iv/(self.NORIENT*12*nside*nside))
|
|
1080
|
+
"""
|
|
1081
|
+
# OLD VERSION OLD VERSION OLD VERSION (3.0)
|
|
990
1082
|
if self.KERNELSZ*self.KERNELSZ>12*nside*nside:
|
|
991
1083
|
l_kernel=3
|
|
992
1084
|
|
|
@@ -1077,7 +1169,7 @@ class FoCUS:
|
|
|
1077
1169
|
w[:,j,i]=wav[:,i,j]
|
|
1078
1170
|
wav=w.flatten()
|
|
1079
1171
|
wwav=wwav.flatten()
|
|
1080
|
-
|
|
1172
|
+
"""
|
|
1081
1173
|
if not self.silent:
|
|
1082
1174
|
print('Write FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
|
|
1083
1175
|
np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice)
|
|
@@ -1248,10 +1340,10 @@ class FoCUS:
|
|
|
1248
1340
|
res=v1/vh
|
|
1249
1341
|
if calc_var:
|
|
1250
1342
|
if self.backend.bk_is_complex(vtmp):
|
|
1251
|
-
res2=self.backend.
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1343
|
+
res2=self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
|
|
1344
|
+
-self.backend.bk_real(res)*self.backend.bk_real(res)) + \
|
|
1345
|
+
(self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
|
|
1346
|
+
-self.backend.bk_imag(res)*self.backend.bk_imag(res)))
|
|
1255
1347
|
else:
|
|
1256
1348
|
res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
|
|
1257
1349
|
return res,res2
|
|
@@ -1269,10 +1361,10 @@ class FoCUS:
|
|
|
1269
1361
|
res=v1/vh
|
|
1270
1362
|
if calc_var:
|
|
1271
1363
|
if self.backend.bk_is_complex(l_x):
|
|
1272
|
-
res2=self.backend.
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1364
|
+
res2=self.backend.bk_sqrt(self.backend.bk_real(v2)/self.backend.bk_real(vh)
|
|
1365
|
+
-self.backend.bk_real(res)*self.backend.bk_real(res) + \
|
|
1366
|
+
self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
|
|
1367
|
+
-self.backend.bk_imag(res)*self.backend.bk_imag(res))
|
|
1276
1368
|
else:
|
|
1277
1369
|
res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
|
|
1278
1370
|
return res,res2
|
|
@@ -1368,7 +1460,25 @@ class FoCUS:
|
|
|
1368
1460
|
padding=self.padding)
|
|
1369
1461
|
|
|
1370
1462
|
return self.backend.bk_reshape(res,shape+[norient])
|
|
1371
|
-
|
|
1463
|
+
|
|
1464
|
+
def diff_data(self,x,y,is_complex=True,sigma=None):
|
|
1465
|
+
if sigma is None:
|
|
1466
|
+
if is_complex:
|
|
1467
|
+
r=self.backend.bk_square(self.backend.bk_real(x)-self.backend.bk_real(y))
|
|
1468
|
+
i=self.backend.bk_square(self.backend.bk_imag(x)-self.backend.bk_imag(y))
|
|
1469
|
+
return self.backend.bk_reduce_sum(r+i)
|
|
1470
|
+
else:
|
|
1471
|
+
r=self.backend.bk_square(x-y)
|
|
1472
|
+
return self.backend.bk_reduce_sum(r)
|
|
1473
|
+
else:
|
|
1474
|
+
if is_complex:
|
|
1475
|
+
r=self.backend.bk_square((self.backend.bk_real(x)-self.backend.bk_real(y))/sigma)
|
|
1476
|
+
i=self.backend.bk_square((self.backend.bk_imag(x)-self.backend.bk_imag(y))/sigma)
|
|
1477
|
+
return self.backend.bk_reduce_sum(r+i)
|
|
1478
|
+
else:
|
|
1479
|
+
r=self.backend.bk_square((x-y)/sigma)
|
|
1480
|
+
return self.backend.bk_reduce_sum(r)
|
|
1481
|
+
|
|
1372
1482
|
# ---------------------------------------------−---------
|
|
1373
1483
|
def convol(self,in_image,axis=0):
|
|
1374
1484
|
|
|
@@ -14,6 +14,7 @@ class GCNN:
|
|
|
14
14
|
n_chan_out=1,
|
|
15
15
|
nbatch=1,
|
|
16
16
|
SEED=1234,
|
|
17
|
+
hidden=None,
|
|
17
18
|
filename=None):
|
|
18
19
|
|
|
19
20
|
if filename is not None:
|
|
@@ -29,7 +30,11 @@ class GCNN:
|
|
|
29
30
|
self.in_nside=outlist[4]
|
|
30
31
|
self.nbatch=outlist[1]
|
|
31
32
|
self.n_chan_out=outlist[8]
|
|
32
|
-
|
|
33
|
+
if len(outlist[9])>0:
|
|
34
|
+
self.hidden=outlist[9]
|
|
35
|
+
else:
|
|
36
|
+
self.hidden=None
|
|
37
|
+
|
|
33
38
|
self.x=self.scat_operator.backend.bk_cast(outlist[6])
|
|
34
39
|
else:
|
|
35
40
|
self.nscale=nscale
|
|
@@ -46,21 +51,33 @@ class GCNN:
|
|
|
46
51
|
self.KERNELSZ= scat_operator.KERNELSZ
|
|
47
52
|
self.all_type= scat_operator.all_type
|
|
48
53
|
self.in_nside=in_nside
|
|
54
|
+
self.hidden=hidden
|
|
49
55
|
|
|
50
56
|
np.random.seed(SEED)
|
|
51
57
|
self.x=scat_operator.backend.bk_cast(np.random.randn(self.get_number_of_weights())/(self.KERNELSZ*self.KERNELSZ))
|
|
52
58
|
|
|
53
59
|
def save(self,filename):
|
|
60
|
+
|
|
61
|
+
if self.hidden is None:
|
|
62
|
+
tabh=[]
|
|
63
|
+
else:
|
|
64
|
+
tabh=self.hidden
|
|
65
|
+
|
|
66
|
+
www= self.get_weights()
|
|
54
67
|
|
|
68
|
+
if not isinstance(www,np.ndarray):
|
|
69
|
+
www=www.numpy()
|
|
70
|
+
|
|
55
71
|
outlist=[self.chanlist, \
|
|
56
72
|
self.nbatch, \
|
|
57
73
|
self.npar, \
|
|
58
74
|
self.KERNELSZ, \
|
|
59
75
|
self.in_nside, \
|
|
60
76
|
self.nscale, \
|
|
61
|
-
|
|
77
|
+
www, \
|
|
62
78
|
self.all_type, \
|
|
63
|
-
self.n_chan_out
|
|
79
|
+
self.n_chan_out, \
|
|
80
|
+
tabh]
|
|
64
81
|
|
|
65
82
|
myout=open("%s.pkl"%(filename),"wb")
|
|
66
83
|
pickle.dump(outlist,myout)
|
|
@@ -68,10 +85,19 @@ class GCNN:
|
|
|
68
85
|
|
|
69
86
|
def get_number_of_weights(self):
|
|
70
87
|
totnchan=0
|
|
88
|
+
szk=self.KERNELSZ*self.KERNELSZ
|
|
89
|
+
if self.hidden is not None:
|
|
90
|
+
totnchan=totnchan+self.hidden[0]*self.npar
|
|
91
|
+
for i in range(1,len(self.hidden)):
|
|
92
|
+
totnchan=totnchan+self.hidden[i]*self.hidden[i-1]
|
|
93
|
+
totnchan=totnchan+self.hidden[len(self.hidden)-1]*12*self.in_nside**2*self.chanlist[0]
|
|
94
|
+
else:
|
|
95
|
+
totnchan=self.npar*12*self.in_nside**2*self.chanlist[0]
|
|
96
|
+
|
|
71
97
|
for i in range(self.nscale):
|
|
72
|
-
totnchan=totnchan+self.chanlist[i]*self.chanlist[i+1]
|
|
73
|
-
|
|
74
|
-
|
|
98
|
+
totnchan=totnchan+self.chanlist[i]*self.chanlist[i+1]*szk
|
|
99
|
+
|
|
100
|
+
return totnchan+self.chanlist[i+1]*self.n_chan_out*szk
|
|
75
101
|
|
|
76
102
|
def set_weights(self,x):
|
|
77
103
|
self.x=x
|
|
@@ -83,20 +109,46 @@ class GCNN:
|
|
|
83
109
|
|
|
84
110
|
x=self.x
|
|
85
111
|
|
|
86
|
-
ww=self.scat_operator.backend.bk_reshape(x[0:self.npar*12*self.in_nside**2*self.chanlist[0]], \
|
|
87
|
-
[self.npar,12*self.in_nside**2*self.chanlist[0]])
|
|
88
112
|
|
|
89
113
|
if axis==0:
|
|
90
114
|
nval=1
|
|
91
115
|
else:
|
|
92
116
|
nval=param.shape[0]
|
|
93
|
-
|
|
117
|
+
|
|
118
|
+
nn=0
|
|
94
119
|
im=self.scat_operator.backend.bk_reshape(param,[nval,self.npar])
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
120
|
+
if self.hidden is not None:
|
|
121
|
+
ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.npar*self.hidden[0]], \
|
|
122
|
+
[self.npar,self.hidden[0]])
|
|
123
|
+
im=self.scat_operator.backend.bk_matmul(im,ww)
|
|
124
|
+
im=self.scat_operator.backend.bk_relu(im)
|
|
125
|
+
nn+=self.npar*self.hidden[0]
|
|
126
|
+
|
|
127
|
+
for i in range(1,len(self.hidden)):
|
|
128
|
+
ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.hidden[i]*self.hidden[i-1]], \
|
|
129
|
+
[self.hidden[i-1],self.hidden[i]])
|
|
130
|
+
im=self.scat_operator.backend.bk_matmul(im,ww)
|
|
131
|
+
im=self.scat_operator.backend.bk_relu(im)
|
|
132
|
+
nn+=self.hidden[i]*self.hidden[i-1]
|
|
133
|
+
|
|
134
|
+
ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.hidden[len(self.hidden)-1]*12*self.in_nside**2*self.chanlist[0]], \
|
|
135
|
+
[self.hidden[len(self.hidden)-1],
|
|
136
|
+
12*self.in_nside**2*self.chanlist[0]])
|
|
137
|
+
im=self.scat_operator.backend.bk_matmul(im,ww)
|
|
138
|
+
im=self.scat_operator.backend.bk_reshape(im,[nval,12*self.in_nside**2,self.chanlist[0]])
|
|
139
|
+
im=self.scat_operator.backend.bk_relu(im)
|
|
140
|
+
nn+=self.hidden[len(self.hidden)-1]*12*self.in_nside**2*self.chanlist[0]
|
|
141
|
+
|
|
142
|
+
else:
|
|
143
|
+
ww=self.scat_operator.backend.bk_reshape(x[0:self.npar*12*self.in_nside**2*self.chanlist[0]], \
|
|
144
|
+
[self.npar,12*self.in_nside**2*self.chanlist[0]])
|
|
145
|
+
im=self.scat_operator.backend.bk_matmul(im,ww)
|
|
146
|
+
im=self.scat_operator.backend.bk_reshape(im,[nval,12*self.in_nside**2,self.chanlist[0]])
|
|
147
|
+
im=self.scat_operator.backend.bk_relu(im)
|
|
98
148
|
|
|
99
|
-
|
|
149
|
+
nn=self.npar*12*self.chanlist[0]*self.in_nside**2
|
|
150
|
+
|
|
151
|
+
|
|
100
152
|
for k in range(self.nscale):
|
|
101
153
|
ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]],
|
|
102
154
|
[self.KERNELSZ*self.KERNELSZ,self.chanlist[k],self.chanlist[k+1]])
|
|
@@ -349,6 +349,15 @@ class foscat_backend:
|
|
|
349
349
|
return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
|
|
350
350
|
if self.BACKEND==self.NUMPY:
|
|
351
351
|
return x.flatten()
|
|
352
|
+
|
|
353
|
+
def bk_size(self,x):
|
|
354
|
+
if self.BACKEND==self.TENSORFLOW:
|
|
355
|
+
return self.backend.size(x)
|
|
356
|
+
if self.BACKEND==self.TORCH:
|
|
357
|
+
return x.numel()
|
|
358
|
+
|
|
359
|
+
if self.BACKEND==self.NUMPY:
|
|
360
|
+
return x.size
|
|
352
361
|
|
|
353
362
|
def bk_resize_image(self,x,shape):
|
|
354
363
|
if self.BACKEND==self.TENSORFLOW:
|
|
@@ -368,11 +377,14 @@ class foscat_backend:
|
|
|
368
377
|
xr=self.bk_real(x)
|
|
369
378
|
xi=self.bk_imag(x)
|
|
370
379
|
|
|
371
|
-
r=self.backend.sign(xr)*self.backend.sqrt(
|
|
372
|
-
i=self.backend.sign(xi)*self.backend.sqrt(
|
|
373
|
-
|
|
380
|
+
r=self.backend.sign(xr)*self.backend.sqrt(xr*xr)
|
|
381
|
+
i=self.backend.sign(xi)*self.backend.sqrt(xi*xi)
|
|
382
|
+
if self.BACKEND==self.TORCH:
|
|
383
|
+
return r
|
|
384
|
+
else:
|
|
385
|
+
return self.bk_complex(r,i)
|
|
374
386
|
else:
|
|
375
|
-
return self.backend.sign(x)*self.backend.sqrt(
|
|
387
|
+
return self.backend.sign(x)*self.backend.sqrt(x*x)
|
|
376
388
|
|
|
377
389
|
def bk_square_comp(self,x):
|
|
378
390
|
if x.dtype==self.all_cbk_type:
|
|
@@ -580,6 +592,13 @@ class foscat_backend:
|
|
|
580
592
|
|
|
581
593
|
if self.BACKEND==self.NUMPY:
|
|
582
594
|
return (data.dtype=='complex64' or data.dtype=='complex128')
|
|
595
|
+
|
|
596
|
+
def bk_distcomp(self,data):
|
|
597
|
+
if self.bk_is_complex(data):
|
|
598
|
+
res=self.bk_square(self.bk_real(data))+self.bk_square(self.bk_imag(data))
|
|
599
|
+
return res
|
|
600
|
+
else:
|
|
601
|
+
return self.bk_square(data)
|
|
583
602
|
|
|
584
603
|
def bk_norm(self,data):
|
|
585
604
|
if self.bk_is_complex(data):
|
|
@@ -727,17 +746,21 @@ class foscat_backend:
|
|
|
727
746
|
if self.BACKEND==self.TENSORFLOW:
|
|
728
747
|
return self.backend.math.real(data)
|
|
729
748
|
if self.BACKEND==self.TORCH:
|
|
730
|
-
return
|
|
749
|
+
return data.real
|
|
731
750
|
if self.BACKEND==self.NUMPY:
|
|
732
|
-
return
|
|
751
|
+
return data.real
|
|
733
752
|
|
|
734
753
|
def bk_imag(self,data):
|
|
735
754
|
if self.BACKEND==self.TENSORFLOW:
|
|
736
755
|
return self.backend.math.imag(data)
|
|
737
756
|
if self.BACKEND==self.TORCH:
|
|
738
|
-
|
|
757
|
+
if data.dtype==self.all_cbk_type:
|
|
758
|
+
return data.imag
|
|
759
|
+
else:
|
|
760
|
+
return 0
|
|
761
|
+
|
|
739
762
|
if self.BACKEND==self.NUMPY:
|
|
740
|
-
return
|
|
763
|
+
return data.imag
|
|
741
764
|
|
|
742
765
|
def bk_relu(self,x):
|
|
743
766
|
if self.BACKEND==self.TENSORFLOW:
|
|
@@ -32,30 +32,50 @@ class loss_backend:
|
|
|
32
32
|
if len(x.shape)>1:
|
|
33
33
|
nx=x.shape[0]
|
|
34
34
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
#sys.stdout.flush()
|
|
35
|
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
36
|
+
if torch.cuda.is_available():
|
|
37
|
+
with torch.cuda.device((operation.gpupos+self.curr_gpu)%operation.ngpu):
|
|
39
38
|
|
|
40
|
-
|
|
39
|
+
l_x=x.clone().detach().requires_grad_(True)
|
|
40
|
+
|
|
41
|
+
if nx==1:
|
|
42
|
+
ndata=x.shape[0]
|
|
43
|
+
else:
|
|
44
|
+
ndata=x.shape[0]*x.shape[1]
|
|
45
|
+
|
|
46
|
+
if KEEP_TRACK is not None:
|
|
47
|
+
l,linfo=loss_function.eval(l_x,batch,return_all=True)
|
|
48
|
+
else:
|
|
49
|
+
l=loss_function.eval(l_x,batch)
|
|
50
|
+
|
|
51
|
+
l.backward()
|
|
41
52
|
|
|
53
|
+
g=l_x.grad
|
|
54
|
+
|
|
55
|
+
self.curr_gpu=self.curr_gpu+1
|
|
56
|
+
else:
|
|
57
|
+
l_x=x.clone().detach().requires_grad_(True)
|
|
58
|
+
|
|
42
59
|
if nx==1:
|
|
43
60
|
ndata=x.shape[0]
|
|
44
61
|
else:
|
|
45
62
|
ndata=x.shape[0]*x.shape[1]
|
|
46
|
-
|
|
63
|
+
|
|
47
64
|
if KEEP_TRACK is not None:
|
|
48
65
|
l,linfo=loss_function.eval(l_x,batch,return_all=True)
|
|
49
66
|
else:
|
|
67
|
+
"""
|
|
68
|
+
sx=operation.eval(l_x)
|
|
69
|
+
tmp=(sx.C01-1.0)
|
|
70
|
+
|
|
71
|
+
l=operation.backend.bk_reduce_sum(tmp*tmp) #loss_function.eval(l_x,batch)
|
|
72
|
+
"""
|
|
73
|
+
|
|
50
74
|
l=loss_function.eval(l_x,batch)
|
|
51
75
|
|
|
52
76
|
l.backward()
|
|
53
77
|
|
|
54
78
|
g=l_x.grad
|
|
55
|
-
|
|
56
|
-
print(g)
|
|
57
|
-
|
|
58
|
-
self.curr_gpu=self.curr_gpu+1
|
|
59
79
|
|
|
60
80
|
if KEEP_TRACK is not None:
|
|
61
81
|
return l.detach(),g,linfo
|
|
@@ -1038,7 +1038,7 @@ class funct(FOC.FoCUS):
|
|
|
1038
1038
|
return scat(mP00,mS0,mS1,mS2,mS2L,tmp.j1,tmp.j2,backend=self.backend), \
|
|
1039
1039
|
scat(sP00,sS0,sS1,sS2,sS2L,tmp.j1,tmp.j2,backend=self.backend)
|
|
1040
1040
|
|
|
1041
|
-
def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,calc_var=False):
|
|
1041
|
+
def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,calc_var=False,norm=None):
|
|
1042
1042
|
# Check input consistency
|
|
1043
1043
|
if image2 is not None:
|
|
1044
1044
|
if list(image1.shape)!=list(image2.shape):
|
|
@@ -1344,6 +1344,30 @@ class funct(FOC.FoCUS):
|
|
|
1344
1344
|
self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
|
|
1345
1345
|
self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
|
|
1346
1346
|
|
|
1347
|
+
def reduce_distance(self, x,y, sigma=None):
|
|
1348
|
+
|
|
1349
|
+
if isinstance(x, scat):
|
|
1350
|
+
if sigma is None:
|
|
1351
|
+
result=self.diff_data(y.S0,x.S0,is_complex=False)
|
|
1352
|
+
result+=self.diff_data(y.S1,x.S1)
|
|
1353
|
+
result+=self.diff_data(y.P00,x.P00)
|
|
1354
|
+
result+=self.diff_data(y.S2,x.S2)
|
|
1355
|
+
result+=self.diff_data(y.S2L,x.S2L)
|
|
1356
|
+
else:
|
|
1357
|
+
result=self.diff_data(y.S0,x.S0,is_complex=False,sigma=sigma.S0)
|
|
1358
|
+
result+=self.diff_data(y.S1,x.S1,sigma=sigma.S1)
|
|
1359
|
+
result+=self.diff_data(y.P00,x.P00,sigma=sigma.P00)
|
|
1360
|
+
result+=self.diff_data(y.S2,x.S2,sigma=sigma.S2)
|
|
1361
|
+
result+=self.diff_data(y.S2L,x.S2L,sigma=sigma.S2L)
|
|
1362
|
+
|
|
1363
|
+
nval=self.backend.bk_size(x.S0)+self.backend.bk_size(x.P00)+ \
|
|
1364
|
+
self.backend.bk_size(x.S1)+self.backend.bk_size(x.S2)+self.backend.bk_size(x.S2L)
|
|
1365
|
+
|
|
1366
|
+
result/=self.backend.bk_cast(nval)
|
|
1367
|
+
else:
|
|
1368
|
+
return self.backend.bk_reduce_sum(x)
|
|
1369
|
+
return result
|
|
1370
|
+
|
|
1347
1371
|
def reduce_mean(self,x,axis=None):
|
|
1348
1372
|
if axis is None:
|
|
1349
1373
|
tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))+ \
|
|
@@ -1682,6 +1682,10 @@ class funct(FOC.FoCUS):
|
|
|
1682
1682
|
if image2 is not None:
|
|
1683
1683
|
tmpi2=self.ud_grade_2(tmpi2,axis=1)
|
|
1684
1684
|
return cmat,cmat2
|
|
1685
|
+
|
|
1686
|
+
def div_norm(self,complex_value,float_value):
|
|
1687
|
+
return self.backend.bk_complex(self.backend.bk_real(complex_value)/float_value,
|
|
1688
|
+
self.backend.bk_imag(complex_value)/float_value)
|
|
1685
1689
|
|
|
1686
1690
|
def eval(self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False,cmat=None,cmat2=None):
|
|
1687
1691
|
"""
|
|
@@ -1931,7 +1935,7 @@ class funct(FOC.FoCUS):
|
|
|
1931
1935
|
else:
|
|
1932
1936
|
### Normalize S1
|
|
1933
1937
|
if norm is not None:
|
|
1934
|
-
s1
|
|
1938
|
+
self.div_norm(s1,(P1_dic[j3]) ** 0.5)
|
|
1935
1939
|
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
1936
1940
|
if S1 is None:
|
|
1937
1941
|
S1 = s1[:, :, None, :] # Add a dimension for NS1
|
|
@@ -2024,7 +2028,7 @@ class funct(FOC.FoCUS):
|
|
|
2024
2028
|
else:
|
|
2025
2029
|
### Normalize S1
|
|
2026
2030
|
if norm is not None:
|
|
2027
|
-
s1
|
|
2031
|
+
self.div_norm(s1,(P1_dic[j3]) ** 0.5)
|
|
2028
2032
|
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
2029
2033
|
if S1 is None:
|
|
2030
2034
|
S1 = s1[:, :, None, :] # Add a dimension for NS1
|
|
@@ -2072,8 +2076,8 @@ class funct(FOC.FoCUS):
|
|
|
2072
2076
|
else:
|
|
2073
2077
|
### Normalize C01 with P00_j [Nbatch, Nmask, Norient_j]
|
|
2074
2078
|
if norm is not None:
|
|
2075
|
-
c01
|
|
2076
|
-
P1_dic[j3][:, :, :, None]) ** 0.5
|
|
2079
|
+
self.div_norm(c01,(P1_dic[j2][:, :, None, :] *
|
|
2080
|
+
P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
|
|
2077
2081
|
|
|
2078
2082
|
### Store C01 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
|
|
2079
2083
|
if C01 is None:
|
|
@@ -2126,10 +2130,10 @@ class funct(FOC.FoCUS):
|
|
|
2126
2130
|
else:
|
|
2127
2131
|
### Normalize C01 and C10 with P00_j [Nbatch, Nmask, Norient_j]
|
|
2128
2132
|
if norm is not None:
|
|
2129
|
-
c01
|
|
2130
|
-
P1_dic[j3][:, :, :, None]) ** 0.5
|
|
2131
|
-
c10
|
|
2132
|
-
P2_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
|
|
2133
|
+
self.div_norm(c01,(P2_dic[j2][:, :, None, :] *
|
|
2134
|
+
P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
|
|
2135
|
+
self.div_norm(c10,(P1_dic[j2][:, :, None, :] *
|
|
2136
|
+
P2_dic[j3][:, :, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2]
|
|
2133
2137
|
|
|
2134
2138
|
### Store C01 and C10 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
|
|
2135
2139
|
if C01 is None:
|
|
@@ -2172,9 +2176,8 @@ class funct(FOC.FoCUS):
|
|
|
2172
2176
|
else:
|
|
2173
2177
|
### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
|
|
2174
2178
|
if norm is not None:
|
|
2175
|
-
c11
|
|
2176
|
-
P1_dic[j2][:, :, None, :,
|
|
2177
|
-
None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
2179
|
+
self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
|
|
2180
|
+
P1_dic[j2][:, :, None, :,None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
2178
2181
|
### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
|
|
2179
2182
|
if C11 is None:
|
|
2180
2183
|
C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
|
|
@@ -2207,8 +2210,8 @@ class funct(FOC.FoCUS):
|
|
|
2207
2210
|
else:
|
|
2208
2211
|
### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
|
|
2209
2212
|
if norm is not None:
|
|
2210
|
-
c11
|
|
2211
|
-
P2_dic[j2][:, :, None, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
2213
|
+
self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
|
|
2214
|
+
P2_dic[j2][:, :, None, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
2212
2215
|
### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
|
|
2213
2216
|
if C11 is None:
|
|
2214
2217
|
C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
|
|
@@ -2394,6 +2397,39 @@ class funct(FOC.FoCUS):
|
|
|
2394
2397
|
else:
|
|
2395
2398
|
return self.backend.bk_reduce_mean(x)
|
|
2396
2399
|
return result
|
|
2400
|
+
|
|
2401
|
+
|
|
2402
|
+
def reduce_distance(self, x,y, sigma=None):
|
|
2403
|
+
|
|
2404
|
+
if isinstance(x, scat_cov):
|
|
2405
|
+
if sigma is None:
|
|
2406
|
+
result=self.diff_data(y.S0,x.S0,is_complex=False)
|
|
2407
|
+
if x.S1 is not None:
|
|
2408
|
+
result+=self.diff_data(y.S1,x.S1)
|
|
2409
|
+
if x.C10 is not None:
|
|
2410
|
+
result+=self.diff_data(y.C10,x.C10)
|
|
2411
|
+
result+=self.diff_data(y.P00,x.P00)
|
|
2412
|
+
result+=self.diff_data(y.C01,x.C01)
|
|
2413
|
+
result+=self.diff_data(y.C11,x.C11)
|
|
2414
|
+
else:
|
|
2415
|
+
result=self.diff_data(y.S0,x.S0,is_complex=False,sigma=sigma.S0)
|
|
2416
|
+
if x.S1 is not None:
|
|
2417
|
+
result+=self.diff_data(y.S1,x.S1,sigma=sigma.S1)
|
|
2418
|
+
if x.C10 is not None:
|
|
2419
|
+
result+=self.diff_data(y.C10,x.C10,sigma=sigma.C10)
|
|
2420
|
+
result+=self.diff_data(y.P00,x.P00,sigma=sigma.P00)
|
|
2421
|
+
result+=self.diff_data(y.C01,x.C01,sigma=sigma.C01)
|
|
2422
|
+
result+=self.diff_data(y.C11,x.C11,sigma=sigma.C11)
|
|
2423
|
+
nval=self.backend.bk_size(x.S0)+self.backend.bk_size(x.P00)+ \
|
|
2424
|
+
self.backend.bk_size(x.C01)+self.backend.bk_size(x.C11)
|
|
2425
|
+
if x.S1 is not None:
|
|
2426
|
+
nval+=self.backend.bk_size(x.S1)
|
|
2427
|
+
if x.C10 is not None:
|
|
2428
|
+
nval+=self.backend.bk_size(x.C10)
|
|
2429
|
+
result/=self.backend.bk_cast(nval)
|
|
2430
|
+
else:
|
|
2431
|
+
return self.backend.bk_reduce_sum(x)
|
|
2432
|
+
return result
|
|
2397
2433
|
|
|
2398
2434
|
def reduce_sum(self, x):
|
|
2399
2435
|
|
|
@@ -2412,8 +2448,7 @@ class funct(FOC.FoCUS):
|
|
|
2412
2448
|
else:
|
|
2413
2449
|
return self.backend.bk_reduce_sum(x)
|
|
2414
2450
|
return result
|
|
2415
|
-
|
|
2416
|
-
|
|
2451
|
+
|
|
2417
2452
|
def ldiff(self,sig,x):
|
|
2418
2453
|
|
|
2419
2454
|
if x.S1 is None:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|