foscat 3.0.35__tar.gz → 3.0.40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30):
  1. {foscat-3.0.35 → foscat-3.0.40}/PKG-INFO +1 -1
  2. {foscat-3.0.35 → foscat-3.0.40}/setup.py +1 -1
  3. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/FoCUS.py +122 -12
  4. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/GCNN.py +65 -13
  5. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/backend.py +31 -8
  6. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/loss_backend_torch.py +30 -10
  7. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat.py +25 -1
  8. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov.py +50 -15
  9. {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/PKG-INFO +1 -1
  10. {foscat-3.0.35 → foscat-3.0.40}/README.md +0 -0
  11. {foscat-3.0.35 → foscat-3.0.40}/setup.cfg +0 -0
  12. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/CNN.py +0 -0
  13. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/CircSpline.py +0 -0
  14. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/GetGPUinfo.py +0 -0
  15. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/Softmax.py +0 -0
  16. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/Spline1D.py +0 -0
  17. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/Synthesis.py +0 -0
  18. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/__init__.py +0 -0
  19. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/backend_tens.py +0 -0
  20. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/loss_backend_tens.py +0 -0
  21. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat1D.py +0 -0
  22. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat2D.py +0 -0
  23. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov1D.py +0 -0
  24. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov2D.py +0 -0
  25. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov_map.py +0 -0
  26. {foscat-3.0.35 → foscat-3.0.40}/src/foscat/scat_cov_map2D.py +0 -0
  27. {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/SOURCES.txt +0 -0
  28. {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/dependency_links.txt +0 -0
  29. {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/requires.txt +0 -0
  30. {foscat-3.0.35 → foscat-3.0.40}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: foscat
3
- Version: 3.0.35
3
+ Version: 3.0.40
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Home-page: https://github.com/jmdelouis/FOSCAT
6
6
  Author: Jean-Marc DELOUIS
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
3
3
 
4
4
  setup(
5
5
  name='foscat',
6
- version='3.0.35',
6
+ version='3.0.40',
7
7
  description='Generate synthetic Healpix or 2D data using Cross Scattering Transform' ,
8
8
  long_description='Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising. \n A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO. \n Complete doc can be found at https://foscat-documentation.readthedocs.io/en/latest/index.html. \n\n List of developers : J.-M. Delouis, T. Foulquier, L. Mousset, T. Odaka, F. Paul, E. Allys ' ,
9
9
  license='MIT',
@@ -5,7 +5,7 @@ import foscat.backend as bk
5
5
  from scipy.interpolate import griddata
6
6
 
7
7
 
8
- TMPFILE_VERSION='V3_0'
8
+ TMPFILE_VERSION='V4_0'
9
9
 
10
10
  class FoCUS:
11
11
  def __init__(self,
@@ -32,7 +32,7 @@ class FoCUS:
32
32
  mpi_size=1,
33
33
  mpi_rank=0):
34
34
 
35
- self.__version__ = '3.0.35'
35
+ self.__version__ = '3.0.40'
36
36
  # P00 coeff for normalization for scat_cov
37
37
  self.TMPFILE_VERSION=TMPFILE_VERSION
38
38
  self.P1_dic = None
@@ -987,6 +987,98 @@ class FoCUS:
987
987
  tmp=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel**2,self.NORIENT,nside))
988
988
  except:
989
989
  if self.use_2D==False:
990
+
991
+ if l_kernel==5:
992
+ pw=0.5
993
+ pw2=0.5
994
+ threshold=2E-4
995
+
996
+ elif l_kernel==3:
997
+ pw=1.0/np.sqrt(2)
998
+ pw2=1.0
999
+ threshold=1E-3
1000
+
1001
+ elif l_kernel==7:
1002
+ pw=0.5
1003
+ pw2=0.25
1004
+ threshold=4E-5
1005
+
1006
+ th,ph=hp.pix2ang(nside,np.arange(12*nside**2),nest=True)
1007
+ x,y,z=hp.pix2vec(nside,np.arange(12*nside**2),nest=True)
1008
+
1009
+ t,p=hp.pix2ang(nside,np.arange(12*nside*nside),nest=True)
1010
+ phi=[p[k]/np.pi*180 for k in range(12*nside*nside)]
1011
+ thi=[t[k]/np.pi*180 for k in range(12*nside*nside)]
1012
+
1013
+
1014
+ indice2=np.zeros([12*nside*nside*64,2],dtype='int')
1015
+ indice=np.zeros([12*nside*nside*64*self.NORIENT,2],dtype='int')
1016
+ wav=np.zeros([12*nside*nside*64*self.NORIENT],dtype='complex')
1017
+ wwav=np.zeros([12*nside*nside*64*self.NORIENT],dtype='float')
1018
+
1019
+ iv=0
1020
+ iv2=0
1021
+ for iii in range(12*nside*nside):
1022
+
1023
+ if iii%(nside*nside)==nside*nside-1:
1024
+ if not self.silent:
1025
+ print('Pre-compute nside=%6d %.2f%%'%(nside,100*iii/(12*nside*nside)))
1026
+
1027
+ hidx=hp.query_disc(nside, [x[iii],y[iii],z[iii]], 2*np.pi/nside,nest=True)
1028
+
1029
+ R=hp.Rotator(rot=[phi[iii],-thi[iii]],eulertype='ZYZ')
1030
+
1031
+ t2,p2=R(th[hidx],ph[hidx])
1032
+
1033
+ vec2=hp.ang2vec(t2,p2)
1034
+
1035
+ x2=vec2[:,0]
1036
+ y2=vec2[:,1]
1037
+ z2=vec2[:,2]
1038
+
1039
+ ww=np.exp(-pw2*((nside)**2)*((x2)**2+(y2)**2+(z2-1.0)**2))
1040
+ idx=np.where((ww**2)>threshold)[0]
1041
+ nval2=len(idx)
1042
+ indice2[iv2:iv2+nval2,0]=iii
1043
+ indice2[iv2:iv2+nval2,1]=hidx[idx]
1044
+ wwav[iv2:iv2+nval2]=ww[idx]/np.sum(ww[idx])
1045
+ iv2+=nval2
1046
+
1047
+ for l in range(self.NORIENT):
1048
+
1049
+ angle=l/4.*np.pi-phi[iii]/180.*np.pi*(z[hidx]>0)-(180.0-phi[iii])/180.*np.pi*(z[hidx]<0)
1050
+
1051
+ #posi=2*(0.5-(z[hidx]<0))
1052
+
1053
+ axes=y2*np.cos(angle)-x2*np.sin(angle)
1054
+ wresr=ww*np.cos(pw*axes*(nside)*np.pi)
1055
+ wresi=ww*np.sin(pw*axes*(nside)*np.pi)
1056
+
1057
+ vnorm=(wresr*wresr+wresi*wresi)
1058
+ idx=np.where(vnorm>threshold)[0]
1059
+
1060
+ nval=len(idx)
1061
+ indice[iv:iv+nval,0]=iii*4+l
1062
+ indice[iv:iv+nval,1]=hidx[idx]
1063
+ #print([hidx[k] for k in idx])
1064
+ #print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1065
+ normr=np.mean(wresr[idx])
1066
+ normi=np.mean(wresi[idx])
1067
+
1068
+ val=wresr[idx]-normr+np.complex(0,1)*(wresi[idx]-normi)
1069
+ val=val/abs(val).sum()
1070
+
1071
+ wav[iv:iv+nval]=val
1072
+ iv+=nval
1073
+
1074
+ indice=indice[:iv,:]
1075
+ wav=wav[:iv]
1076
+ indice2=indice2[:iv2,:]
1077
+ wwav=wwav[:iv2]
1078
+ if not self.silent:
1079
+ print('Kernel Size ',iv/(self.NORIENT*12*nside*nside))
1080
+ """
1081
+ # OLD VERSION OLD VERSION OLD VERSION (3.0)
990
1082
  if self.KERNELSZ*self.KERNELSZ>12*nside*nside:
991
1083
  l_kernel=3
992
1084
 
@@ -1077,7 +1169,7 @@ class FoCUS:
1077
1169
  w[:,j,i]=wav[:,i,j]
1078
1170
  wav=w.flatten()
1079
1171
  wwav=wwav.flatten()
1080
-
1172
+ """
1081
1173
  if not self.silent:
1082
1174
  print('Write FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
1083
1175
  np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice)
@@ -1248,10 +1340,10 @@ class FoCUS:
1248
1340
  res=v1/vh
1249
1341
  if calc_var:
1250
1342
  if self.backend.bk_is_complex(vtmp):
1251
- res2=self.backend.bk_complex(self.backend.bk_sqrt(self.backend.bk_real(v2)/self.backend.bk_real(vh)
1252
- -self.backend.bk_real(res)*self.backend.bk_real(res)), \
1253
- self.backend.bk_sqrt(self.backend.bk_imag(v2)/self.backend.bk_real(vh)
1254
- -self.backend.bk_imag(res)*self.backend.bk_imag(res)))
1343
+ res2=self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1344
+ -self.backend.bk_real(res)*self.backend.bk_real(res)) + \
1345
+ (self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
1346
+ -self.backend.bk_imag(res)*self.backend.bk_imag(res)))
1255
1347
  else:
1256
1348
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1257
1349
  return res,res2
@@ -1269,10 +1361,10 @@ class FoCUS:
1269
1361
  res=v1/vh
1270
1362
  if calc_var:
1271
1363
  if self.backend.bk_is_complex(l_x):
1272
- res2=self.backend.bk_complex(self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1273
- -self.backend.bk_real(res)*self.backend.bk_real(res))/self.backend.bk_real(v2)), \
1274
- self.backend.bk_sqrt((self.backend.bk_imag(v2)/self.backend.bk_real(vh)
1275
- -self.backend.bk_imag(res)*self.backend.bk_imag(res))/self.backend.bk_real(v2)))
1364
+ res2=self.backend.bk_sqrt(self.backend.bk_real(v2)/self.backend.bk_real(vh)
1365
+ -self.backend.bk_real(res)*self.backend.bk_real(res) + \
1366
+ self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
1367
+ -self.backend.bk_imag(res)*self.backend.bk_imag(res))
1276
1368
  else:
1277
1369
  res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1278
1370
  return res,res2
@@ -1368,7 +1460,25 @@ class FoCUS:
1368
1460
  padding=self.padding)
1369
1461
 
1370
1462
  return self.backend.bk_reshape(res,shape+[norient])
1371
-
1463
+
1464
def diff_data(self,x,y,is_complex=True,sigma=None):
    """Return the sum of squared residuals between ``x`` and ``y``.

    Parameters
    ----------
    x, y : backend tensors
        Values to compare.
    is_complex : bool
        When True, the real and imaginary parts (``bk_real`` / ``bk_imag``)
        are differenced separately and both squared residuals are summed.
        When False, ``x - y`` is squared directly.
    sigma : optional
        If given, every residual is divided by ``sigma`` before squaring.

    Returns
    -------
    The value produced by ``bk_reduce_sum`` over all squared residuals.
    """
    bk = self.backend

    def weighted(residual):
        # Only rescale when an uncertainty has been supplied.
        return residual if sigma is None else residual / sigma

    if not is_complex:
        return bk.bk_reduce_sum(bk.bk_square(weighted(x - y)))

    real_res = weighted(bk.bk_real(x) - bk.bk_real(y))
    imag_res = weighted(bk.bk_imag(x) - bk.bk_imag(y))
    return bk.bk_reduce_sum(bk.bk_square(real_res) + bk.bk_square(imag_res))
1481
+
1372
1482
  # ---------------------------------------------−---------
1373
1483
  def convol(self,in_image,axis=0):
1374
1484
 
@@ -14,6 +14,7 @@ class GCNN:
14
14
  n_chan_out=1,
15
15
  nbatch=1,
16
16
  SEED=1234,
17
+ hidden=None,
17
18
  filename=None):
18
19
 
19
20
  if filename is not None:
@@ -29,7 +30,11 @@ class GCNN:
29
30
  self.in_nside=outlist[4]
30
31
  self.nbatch=outlist[1]
31
32
  self.n_chan_out=outlist[8]
32
-
33
+ if len(outlist[9])>0:
34
+ self.hidden=outlist[9]
35
+ else:
36
+ self.hidden=None
37
+
33
38
  self.x=self.scat_operator.backend.bk_cast(outlist[6])
34
39
  else:
35
40
  self.nscale=nscale
@@ -46,21 +51,33 @@ class GCNN:
46
51
  self.KERNELSZ= scat_operator.KERNELSZ
47
52
  self.all_type= scat_operator.all_type
48
53
  self.in_nside=in_nside
54
+ self.hidden=hidden
49
55
 
50
56
  np.random.seed(SEED)
51
57
  self.x=scat_operator.backend.bk_cast(np.random.randn(self.get_number_of_weights())/(self.KERNELSZ*self.KERNELSZ))
52
58
 
53
59
  def save(self,filename):
60
+
61
+ if self.hidden is None:
62
+ tabh=[]
63
+ else:
64
+ tabh=self.hidden
65
+
66
+ www= self.get_weights()
54
67
 
68
+ if not isinstance(www,np.ndarray):
69
+ www=www.numpy()
70
+
55
71
  outlist=[self.chanlist, \
56
72
  self.nbatch, \
57
73
  self.npar, \
58
74
  self.KERNELSZ, \
59
75
  self.in_nside, \
60
76
  self.nscale, \
61
- self.get_weights().numpy(), \
77
+ www, \
62
78
  self.all_type, \
63
- self.n_chan_out]
79
+ self.n_chan_out, \
80
+ tabh]
64
81
 
65
82
  myout=open("%s.pkl"%(filename),"wb")
66
83
  pickle.dump(outlist,myout)
@@ -68,10 +85,19 @@ class GCNN:
68
85
 
69
86
  def get_number_of_weights(self):
70
87
  totnchan=0
88
+ szk=self.KERNELSZ*self.KERNELSZ
89
+ if self.hidden is not None:
90
+ totnchan=totnchan+self.hidden[0]*self.npar
91
+ for i in range(1,len(self.hidden)):
92
+ totnchan=totnchan+self.hidden[i]*self.hidden[i-1]
93
+ totnchan=totnchan+self.hidden[len(self.hidden)-1]*12*self.in_nside**2*self.chanlist[0]
94
+ else:
95
+ totnchan=self.npar*12*self.in_nside**2*self.chanlist[0]
96
+
71
97
  for i in range(self.nscale):
72
- totnchan=totnchan+self.chanlist[i]*self.chanlist[i+1]
73
- return self.npar*12*self.in_nside**2*self.chanlist[0] \
74
- +(totnchan+self.chanlist[i+1]*self.n_chan_out)*self.KERNELSZ*self.KERNELSZ
98
+ totnchan=totnchan+self.chanlist[i]*self.chanlist[i+1]*szk
99
+
100
+ return totnchan+self.chanlist[i+1]*self.n_chan_out*szk
75
101
 
76
102
  def set_weights(self,x):
77
103
  self.x=x
@@ -83,20 +109,46 @@ class GCNN:
83
109
 
84
110
  x=self.x
85
111
 
86
- ww=self.scat_operator.backend.bk_reshape(x[0:self.npar*12*self.in_nside**2*self.chanlist[0]], \
87
- [self.npar,12*self.in_nside**2*self.chanlist[0]])
88
112
 
89
113
  if axis==0:
90
114
  nval=1
91
115
  else:
92
116
  nval=param.shape[0]
93
-
117
+
118
+ nn=0
94
119
  im=self.scat_operator.backend.bk_reshape(param,[nval,self.npar])
95
- im=self.scat_operator.backend.bk_matmul(im,ww)
96
- im=self.scat_operator.backend.bk_reshape(im,[nval,12*self.in_nside**2,self.chanlist[0]])
97
- im=self.scat_operator.backend.bk_relu(im)
120
+ if self.hidden is not None:
121
+ ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.npar*self.hidden[0]], \
122
+ [self.npar,self.hidden[0]])
123
+ im=self.scat_operator.backend.bk_matmul(im,ww)
124
+ im=self.scat_operator.backend.bk_relu(im)
125
+ nn+=self.npar*self.hidden[0]
126
+
127
+ for i in range(1,len(self.hidden)):
128
+ ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.hidden[i]*self.hidden[i-1]], \
129
+ [self.hidden[i-1],self.hidden[i]])
130
+ im=self.scat_operator.backend.bk_matmul(im,ww)
131
+ im=self.scat_operator.backend.bk_relu(im)
132
+ nn+=self.hidden[i]*self.hidden[i-1]
133
+
134
+ ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.hidden[len(self.hidden)-1]*12*self.in_nside**2*self.chanlist[0]], \
135
+ [self.hidden[len(self.hidden)-1],
136
+ 12*self.in_nside**2*self.chanlist[0]])
137
+ im=self.scat_operator.backend.bk_matmul(im,ww)
138
+ im=self.scat_operator.backend.bk_reshape(im,[nval,12*self.in_nside**2,self.chanlist[0]])
139
+ im=self.scat_operator.backend.bk_relu(im)
140
+ nn+=self.hidden[len(self.hidden)-1]*12*self.in_nside**2*self.chanlist[0]
141
+
142
+ else:
143
+ ww=self.scat_operator.backend.bk_reshape(x[0:self.npar*12*self.in_nside**2*self.chanlist[0]], \
144
+ [self.npar,12*self.in_nside**2*self.chanlist[0]])
145
+ im=self.scat_operator.backend.bk_matmul(im,ww)
146
+ im=self.scat_operator.backend.bk_reshape(im,[nval,12*self.in_nside**2,self.chanlist[0]])
147
+ im=self.scat_operator.backend.bk_relu(im)
98
148
 
99
- nn=self.npar*12*self.chanlist[0]*self.in_nside**2
149
+ nn=self.npar*12*self.chanlist[0]*self.in_nside**2
150
+
151
+
100
152
  for k in range(self.nscale):
101
153
  ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]],
102
154
  [self.KERNELSZ*self.KERNELSZ,self.chanlist[k],self.chanlist[k+1]])
@@ -349,6 +349,15 @@ class foscat_backend:
349
349
  return self.backend.reshape(x,[np.prod(np.array(list(x.shape)))])
350
350
  if self.BACKEND==self.NUMPY:
351
351
  return x.flatten()
352
+
353
def bk_size(self,x):
    """Return the total number of elements in ``x`` for the active backend.

    TensorFlow uses ``self.backend.size``, torch uses ``x.numel()`` and NumPy
    uses ``x.size``. The first backend id equal to ``self.BACKEND`` wins. An
    unrecognised backend yields ``None``.
    """
    counters = (
        (self.TENSORFLOW, lambda t: self.backend.size(t)),
        (self.TORCH, lambda t: t.numel()),
        (self.NUMPY, lambda t: t.size),
    )
    for backend_id, count in counters:
        if self.BACKEND == backend_id:
            return count(x)
    return None
352
361
 
353
362
  def bk_resize_image(self,x,shape):
354
363
  if self.BACKEND==self.TENSORFLOW:
@@ -368,11 +377,14 @@ class foscat_backend:
368
377
  xr=self.bk_real(x)
369
378
  xi=self.bk_imag(x)
370
379
 
371
- r=self.backend.sign(xr)*self.backend.sqrt(self.backend.sign(xr)*xr)
372
- i=self.backend.sign(xi)*self.backend.sqrt(self.backend.sign(xi)*xi)
373
- return self.bk_complex(r,i)
380
+ r=self.backend.sign(xr)*self.backend.sqrt(xr*xr)
381
+ i=self.backend.sign(xi)*self.backend.sqrt(xi*xi)
382
+ if self.BACKEND==self.TORCH:
383
+ return r
384
+ else:
385
+ return self.bk_complex(r,i)
374
386
  else:
375
- return self.backend.sign(x)*self.backend.sqrt(self.backend.sign(x)*x)
387
+ return self.backend.sign(x)*self.backend.sqrt(x*x)
376
388
 
377
389
  def bk_square_comp(self,x):
378
390
  if x.dtype==self.all_cbk_type:
@@ -580,6 +592,13 @@ class foscat_backend:
580
592
 
581
593
  if self.BACKEND==self.NUMPY:
582
594
  return (data.dtype=='complex64' or data.dtype=='complex128')
595
+
596
def bk_distcomp(self,data):
    """Return the element-wise squared modulus of ``data``.

    Complex input gives ``real**2 + imag**2``. Real input is simply squared.
    """
    if not self.bk_is_complex(data):
        return self.bk_square(data)
    real_part = self.bk_real(data)
    imag_part = self.bk_imag(data)
    return self.bk_square(real_part) + self.bk_square(imag_part)
583
602
 
584
603
  def bk_norm(self,data):
585
604
  if self.bk_is_complex(data):
@@ -727,17 +746,21 @@ class foscat_backend:
727
746
  if self.BACKEND==self.TENSORFLOW:
728
747
  return self.backend.math.real(data)
729
748
  if self.BACKEND==self.TORCH:
730
- return self.backend.real(data)
749
+ return data.real
731
750
  if self.BACKEND==self.NUMPY:
732
- return self.backend.real(data)
751
+ return data.real
733
752
 
734
753
  def bk_imag(self,data):
735
754
  if self.BACKEND==self.TENSORFLOW:
736
755
  return self.backend.math.imag(data)
737
756
  if self.BACKEND==self.TORCH:
738
- return self.backend.imag(data)
757
+ if data.dtype==self.all_cbk_type:
758
+ return data.imag
759
+ else:
760
+ return 0
761
+
739
762
  if self.BACKEND==self.NUMPY:
740
- return self.backend.imag(data)
763
+ return data.imag
741
764
 
742
765
  def bk_relu(self,x):
743
766
  if self.BACKEND==self.TENSORFLOW:
@@ -32,30 +32,50 @@ class loss_backend:
32
32
  if len(x.shape)>1:
33
33
  nx=x.shape[0]
34
34
 
35
- with torch.cuda.device((operation.gpupos+self.curr_gpu)%operation.ngpu):
36
- #print('%s Run [PROC=%04d] on GPU %s'%(loss_function.name,self.mpi_rank,
37
- # operation.gpulist[(operation.gpupos+self.curr_gpu)%operation.ngpu]))
38
- #sys.stdout.flush()
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ if torch.cuda.is_available():
37
+ with torch.cuda.device((operation.gpupos+self.curr_gpu)%operation.ngpu):
39
38
 
40
- l_x=x.clone().detach().requires_grad_(True)
39
+ l_x=x.clone().detach().requires_grad_(True)
40
+
41
+ if nx==1:
42
+ ndata=x.shape[0]
43
+ else:
44
+ ndata=x.shape[0]*x.shape[1]
45
+
46
+ if KEEP_TRACK is not None:
47
+ l,linfo=loss_function.eval(l_x,batch,return_all=True)
48
+ else:
49
+ l=loss_function.eval(l_x,batch)
50
+
51
+ l.backward()
41
52
 
53
+ g=l_x.grad
54
+
55
+ self.curr_gpu=self.curr_gpu+1
56
+ else:
57
+ l_x=x.clone().detach().requires_grad_(True)
58
+
42
59
  if nx==1:
43
60
  ndata=x.shape[0]
44
61
  else:
45
62
  ndata=x.shape[0]*x.shape[1]
46
-
63
+
47
64
  if KEEP_TRACK is not None:
48
65
  l,linfo=loss_function.eval(l_x,batch,return_all=True)
49
66
  else:
67
+ """
68
+ sx=operation.eval(l_x)
69
+ tmp=(sx.C01-1.0)
70
+
71
+ l=operation.backend.bk_reduce_sum(tmp*tmp) #loss_function.eval(l_x,batch)
72
+ """
73
+
50
74
  l=loss_function.eval(l_x,batch)
51
75
 
52
76
  l.backward()
53
77
 
54
78
  g=l_x.grad
55
-
56
- print(g)
57
-
58
- self.curr_gpu=self.curr_gpu+1
59
79
 
60
80
  if KEEP_TRACK is not None:
61
81
  return l.detach(),g,linfo
@@ -1038,7 +1038,7 @@ class funct(FOC.FoCUS):
1038
1038
  return scat(mP00,mS0,mS1,mS2,mS2L,tmp.j1,tmp.j2,backend=self.backend), \
1039
1039
  scat(sP00,sS0,sS1,sS2,sS2L,tmp.j1,tmp.j2,backend=self.backend)
1040
1040
 
1041
- def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,calc_var=False):
1041
+ def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,calc_var=False,norm=None):
1042
1042
  # Check input consistency
1043
1043
  if image2 is not None:
1044
1044
  if list(image1.shape)!=list(image2.shape):
@@ -1344,6 +1344,30 @@ class funct(FOC.FoCUS):
1344
1344
  self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1345
1345
  self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1346
1346
 
1347
def reduce_distance(self, x,y, sigma=None):
    """Return the mean squared residual between two ``scat`` coefficient sets.

    Each of S0 (compared as real data), S1, P00, S2 and S2L contributes
    ``diff_data(y.<field>, x.<field>)``. When ``sigma`` is given, the field of
    the same name on ``sigma`` scales the residuals. The accumulated sum is
    divided by the total element count of those fields of ``x``.

    If ``x`` is not a ``scat`` instance, ``bk_reduce_sum(x)`` is returned.
    """
    if not isinstance(x, scat):
        # NOTE(review): this branch ignores ``y`` and ``sigma`` and only sums
        # ``x``; confirm that is intended for a "distance".
        return self.backend.bk_reduce_sum(x)

    def field_sigma(name):
        return None if sigma is None else getattr(sigma, name)

    result = self.diff_data(y.S0, x.S0, is_complex=False, sigma=field_sigma('S0'))
    for name in ('S1', 'P00', 'S2', 'S2L'):
        result += self.diff_data(getattr(y, name), getattr(x, name),
                                 sigma=field_sigma(name))

    count = self.backend.bk_size(x.S0)
    for name in ('P00', 'S1', 'S2', 'S2L'):
        count = count + self.backend.bk_size(getattr(x, name))

    result /= self.backend.bk_cast(count)
    return result
1370
+
1347
1371
  def reduce_mean(self,x,axis=None):
1348
1372
  if axis is None:
1349
1373
  tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))+ \
@@ -1682,6 +1682,10 @@ class funct(FOC.FoCUS):
1682
1682
  if image2 is not None:
1683
1683
  tmpi2=self.ud_grade_2(tmpi2,axis=1)
1684
1684
  return cmat,cmat2
1685
+
1686
def div_norm(self,complex_value,float_value):
    """Divide ``complex_value`` by the real factor ``float_value``.

    The real and imaginary parts are divided separately and recombined with
    ``bk_complex``. The input is not modified; the caller must use the
    returned value.

    NOTE(review): several ``norm is not None`` branches in ``eval`` call this
    helper without assigning its result, so their normalisation has no
    effect. They presumably should read ``s1 = self.div_norm(...)``.
    """
    bk = self.backend
    real_scaled = bk.bk_real(complex_value) / float_value
    imag_scaled = bk.bk_imag(complex_value) / float_value
    return bk.bk_complex(real_scaled, imag_scaled)
1685
1689
 
1686
1690
  def eval(self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False,cmat=None,cmat2=None):
1687
1691
  """
@@ -1931,7 +1935,7 @@ class funct(FOC.FoCUS):
1931
1935
  else:
1932
1936
  ### Normalize S1
1933
1937
  if norm is not None:
1934
- s1 /= (P1_dic[j3]) ** 0.5
1938
+ self.div_norm(s1,(P1_dic[j3]) ** 0.5)
1935
1939
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
1936
1940
  if S1 is None:
1937
1941
  S1 = s1[:, :, None, :] # Add a dimension for NS1
@@ -2024,7 +2028,7 @@ class funct(FOC.FoCUS):
2024
2028
  else:
2025
2029
  ### Normalize S1
2026
2030
  if norm is not None:
2027
- s1 /= (P1_dic[j3]) ** 0.5
2031
+ self.div_norm(s1,(P1_dic[j3]) ** 0.5)
2028
2032
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2029
2033
  if S1 is None:
2030
2034
  S1 = s1[:, :, None, :] # Add a dimension for NS1
@@ -2072,8 +2076,8 @@ class funct(FOC.FoCUS):
2072
2076
  else:
2073
2077
  ### Normalize C01 with P00_j [Nbatch, Nmask, Norient_j]
2074
2078
  if norm is not None:
2075
- c01 /= (P1_dic[j2][:, :, None, :] *
2076
- P1_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
2079
+ self.div_norm(c01,(P1_dic[j2][:, :, None, :] *
2080
+ P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2077
2081
 
2078
2082
  ### Store C01 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
2079
2083
  if C01 is None:
@@ -2126,10 +2130,10 @@ class funct(FOC.FoCUS):
2126
2130
  else:
2127
2131
  ### Normalize C01 and C10 with P00_j [Nbatch, Nmask, Norient_j]
2128
2132
  if norm is not None:
2129
- c01 /= (P2_dic[j2][:, :, None, :] *
2130
- P1_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
2131
- c10 /= (P1_dic[j2][:, :, None, :] *
2132
- P2_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
2133
+ self.div_norm(c01,(P2_dic[j2][:, :, None, :] *
2134
+ P1_dic[j3][:, :, :, None]) ** 0.5)# [Nbatch, Nmask, Norient3, Norient2]
2135
+ self.div_norm(c10,(P1_dic[j2][:, :, None, :] *
2136
+ P2_dic[j3][:, :, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2]
2133
2137
 
2134
2138
  ### Store C01 and C10 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
2135
2139
  if C01 is None:
@@ -2172,9 +2176,8 @@ class funct(FOC.FoCUS):
2172
2176
  else:
2173
2177
  ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
2174
2178
  if norm is not None:
2175
- c11 /= (P1_dic[j1][:, :, None, None, :] *
2176
- P1_dic[j2][:, :, None, :,
2177
- None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2179
+ self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
2180
+ P1_dic[j2][:, :, None, :,None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2178
2181
  ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
2179
2182
  if C11 is None:
2180
2183
  C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
@@ -2207,8 +2210,8 @@ class funct(FOC.FoCUS):
2207
2210
  else:
2208
2211
  ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
2209
2212
  if norm is not None:
2210
- c11 /= (P1_dic[j1][:, :, None, None, :] *
2211
- P2_dic[j2][:, :, None, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2213
+ self.div_norm(c11,(P1_dic[j1][:, :, None, None, :] *
2214
+ P2_dic[j2][:, :, None, :, None]) ** 0.5) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
2212
2215
  ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
2213
2216
  if C11 is None:
2214
2217
  C11 = c11[:, :, None, :, :, :] # Add a dimension for NC11
@@ -2394,6 +2397,39 @@ class funct(FOC.FoCUS):
2394
2397
  else:
2395
2398
  return self.backend.bk_reduce_mean(x)
2396
2399
  return result
2400
+
2401
+
2402
def reduce_distance(self, x,y, sigma=None):
    """Return the mean squared residual between two ``scat_cov`` objects.

    S0 (compared as real data), P00, C01 and C11 always contribute
    ``diff_data(y.<field>, x.<field>)``. S1 and C10 contribute only when the
    corresponding field of ``x`` is not None. When ``sigma`` is given, the
    field of the same name on ``sigma`` scales the residuals. The sum is
    divided by the total element count of the contributing fields of ``x``.

    If ``x`` is not a ``scat_cov`` instance, ``bk_reduce_sum(x)`` is returned.
    """
    if not isinstance(x, scat_cov):
        # NOTE(review): this branch ignores ``y`` and ``sigma`` and only sums
        # ``x``; confirm that is intended for a "distance".
        return self.backend.bk_reduce_sum(x)

    def field_sigma(name):
        return None if sigma is None else getattr(sigma, name)

    # Optional coefficient groups that are present on x.
    optional = [name for name in ('S1', 'C10') if getattr(x, name) is not None]

    result = self.diff_data(y.S0, x.S0, is_complex=False, sigma=field_sigma('S0'))
    for name in optional + ['P00', 'C01', 'C11']:
        result += self.diff_data(getattr(y, name), getattr(x, name),
                                 sigma=field_sigma(name))

    count = self.backend.bk_size(x.S0)
    for name in ['P00', 'C01', 'C11'] + optional:
        count = count + self.backend.bk_size(getattr(x, name))

    result /= self.backend.bk_cast(count)
    return result
2397
2433
 
2398
2434
  def reduce_sum(self, x):
2399
2435
 
@@ -2412,8 +2448,7 @@ class funct(FOC.FoCUS):
2412
2448
  else:
2413
2449
  return self.backend.bk_reduce_sum(x)
2414
2450
  return result
2415
-
2416
-
2451
+
2417
2452
  def ldiff(self,sig,x):
2418
2453
 
2419
2454
  if x.S1 is None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: foscat
3
- Version: 3.0.35
3
+ Version: 3.0.40
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Home-page: https://github.com/jmdelouis/FOSCAT
6
6
  Author: Jean-Marc DELOUIS
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes