foscat 3.0.9.tar.gz → 3.0.14.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {foscat-3.0.9 → foscat-3.0.14}/PKG-INFO +1 -1
  2. {foscat-3.0.9 → foscat-3.0.14}/setup.py +1 -1
  3. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/FoCUS.py +140 -27
  4. foscat-3.0.14/src/foscat/GCNN.py +100 -0
  5. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/Synthesis.py +11 -6
  6. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/backend.py +41 -1
  7. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov.py +135 -20
  8. {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/PKG-INFO +1 -1
  9. {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/SOURCES.txt +1 -0
  10. {foscat-3.0.9 → foscat-3.0.14}/README.md +0 -0
  11. {foscat-3.0.9 → foscat-3.0.14}/setup.cfg +0 -0
  12. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/CircSpline.py +0 -0
  13. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/GetGPUinfo.py +0 -0
  14. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/Softmax.py +0 -0
  15. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/Spline1D.py +0 -0
  16. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/__init__.py +0 -0
  17. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/loss_backend_tens.py +0 -0
  18. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/loss_backend_torch.py +0 -0
  19. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat.py +0 -0
  20. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat1D.py +0 -0
  21. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat2D.py +0 -0
  22. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov1D.py +0 -0
  23. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov2D.py +0 -0
  24. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov_map.py +0 -0
  25. {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov_map2D.py +0 -0
  26. {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/dependency_links.txt +0 -0
  27. {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/requires.txt +0 -0
  28. {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/top_level.txt +0 -0
{foscat-3.0.9 → foscat-3.0.14}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: foscat
- Version: 3.0.9
+ Version: 3.0.14
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
  Home-page: https://github.com/jmdelouis/FOSCAT
  Author: Jean-Marc DELOUIS
{foscat-3.0.9 → foscat-3.0.14}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 
  setup(
  name='foscat',
- version='3.0.9',
+ version='3.0.14',
  description='Generate synthetic Healpix or 2D data using Cross Scattering Transform' ,
  long_description='Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising. \n A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO. \n Complete doc can be found at https://foscat-documentation.readthedocs.io/en/latest/index.html. \n\n List of developers : J.-M. Delouis, T. Foulquier, L. Mousset, T. Odaka, F. Paul, E. Allys ' ,
  license='MIT',
{foscat-3.0.9 → foscat-3.0.14}/src/foscat/FoCUS.py
@@ -27,6 +27,7 @@ class FoCUS:
  JmaxDelta=0,
  DODIV=False,
  InitWave=None,
+ silent=False,
  mpi_size=1,
  mpi_rank=0):
 
@@ -42,20 +43,25 @@ class FoCUS:
  self.mpi_size=mpi_size
  self.mpi_rank=mpi_rank
  self.return_data=return_data
-
- print('================================================')
- print(' START FOSCAT CONFIGURATION')
- print('================================================')
- sys.stdout.flush()
+ self.silent=silent
+
+ if not silent:
+ print('================================================')
+ print(' START FOSCAT CONFIGURATION')
+ print('================================================')
+ sys.stdout.flush()
 
  self.TEMPLATE_PATH=TEMPLATE_PATH
  if os.path.exists(self.TEMPLATE_PATH)==False:
- print('The directory %s to store temporary information for FoCUS does not exist: Try to create it'%(self.TEMPLATE_PATH))
+ if not silent:
+ print('The directory %s to store temporary information for FoCUS does not exist: Try to create it'%(self.TEMPLATE_PATH))
  try:
  os.system('mkdir -p %s'%(self.TEMPLATE_PATH))
- print('The directory %s is created')
+ if not silent:
+ print('The directory %s is created')
  except:
- print('Impossible to create the directory %s'%(self.TEMPLATE_PATH))
+ if not silent:
+ print('Impossible to create the directory %s'%(self.TEMPLATE_PATH))
  exit(0)
 
  self.number_of_loss=0
@@ -65,13 +71,15 @@ class FoCUS:
  self.padding=padding
 
  if OSTEP!=0:
- print('OPTION option is deprecated after version 2.0.6. Please use Jmax option')
+ if not silent:
+ print('OPTION option is deprecated after version 2.0.6. Please use Jmax option')
  JmaxDelta=OSTEP
  else:
  OSTEP=JmaxDelta
 
  if JmaxDelta<-1:
- print('Warning : Jmax can not be smaller than -1')
+ if not silent:
+ print('Warning : Jmax can not be smaller than -1')
  exit(0)
 
  self.OSTEP=JmaxDelta
@@ -103,14 +111,15 @@ class FoCUS:
 
  self.gpupos=(gpupos+mpi_rank)%self.backend.ngpu
 
- print('============================================================')
- print('== ==')
- print('== ==')
- print('== RUN ON GPU Rank %d : %s =='%(mpi_rank,self.gpulist[self.gpupos%self.ngpu]))
- print('== ==')
- print('== ==')
- print('============================================================')
- sys.stdout.flush()
+ if not silent:
+ print('============================================================')
+ print('== ==')
+ print('== ==')
+ print('== RUN ON GPU Rank %d : %s =='%(mpi_rank,self.gpulist[self.gpupos%self.ngpu]))
+ print('== ==')
+ print('== ==')
+ print('============================================================')
+ sys.stdout.flush()
 
  l_NORIENT=NORIENT
  if DODIV:
@@ -126,6 +135,8 @@ class FoCUS:
 
  self.ww_Real = {}
  self.ww_Imag = {}
+ self.ww_CNN_Transpose = {}
+ self.ww_CNN = {}
 
  wwc=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
  wws=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
@@ -210,7 +221,8 @@ class FoCUS:
 
  for i in range(1,6):
  lout=(2**i)
- print('Init Wave ',lout)
+ if not silent:
+ print('Init Wave ',lout)
 
  if self.InitWave is None:
  wr,wi,ws,widx=self.init_index(lout)
@@ -271,6 +283,8 @@ class FoCUS:
  self.nest2R4[lout]=None
  self.inv_nest2R[lout]=None
  self.remove_border[lout]=None
+ self.ww_CNN_Transpose[lout]=None
+ self.ww_CNN[lout]=None
 
  self.loss={}
 
@@ -296,7 +310,97 @@ class FoCUS:
  res=x
  res[idx]=y[idx]
  return(res)
+
+ # ---------------------------------------------−---------
+ # make the CNN working : index reporjection of the kernel on healpix
+
+ def init_CNN_index(self,nside,transpose=False):
+ l_kernel=int(self.KERNELSZ*self.KERNELSZ)
+ weights=self.backend.bk_cast(np.ones([12*nside*nside*l_kernel],dtype='float'))
+ try:
+ if transpose:
+ indices=np.load('%s/FOSCAT_%s_W%d_%d_%d_CNN_Transpose.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
+ else:
+ indices=np.load('%s/FOSCAT_%s_W%d_%d_%d_CNN.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
+ except:
+ to,po=hp.pix2ang(nside,np.arange(12*nside*nside),nest=True)
+ x,y,z=hp.pix2vec(nside,np.arange(12*nside*nside),nest=True)
+
+ idx=np.argsort((x-1.0)**2+y**2+z**2)[0:l_kernel]
+ tc,pc=hp.pix2ang(nside,idx,nest=True)
+
+ indices=np.zeros([12*nside*nside,l_kernel,2],dtype='int')
+ for k in range(12*nside*nside):
+ if k%(nside*nside)==0:
+ if not silent:
+ print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
+
+ rot=[po[k]/np.pi*180.0,90+(-to[k])/np.pi*180.0]
+ r=hp.Rotator(rot=rot).get_inverse()
+ # get the coordinate
+ ty,tx=r(tc,pc)
+
+ indices[k,:,0]=k*l_kernel+np.arange(l_kernel).astype('int')
+ indices[k,:,1]=hp.ang2pix(nside,ty,tx,nest=True)
+ if transpose:
+ indices[:,:,1]=indices[:,:,1]//4
+ np.save('%s/FOSCAT_%s_W%d_%d_%d_CNN_Transpose.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside),indices)
+ if not silent:
+ print('Write %s/FOSCAT_%s_W%d_%d_%d_CNN_Transpose.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
+ else:
+ np.save('%s/FOSCAT_%s_W%d_%d_%d_CNNnpy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside),indices)
+ if not silent:
+ print('Write %s/FOSCAT_%s_W%d_%d_%d_CNN.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
+
+
+ if transpose:
+ self.ww_CNN_Transpose[nside]=self.backend.bk_SparseTensor(indices.reshape(12*nside*nside*l_kernel,2),
+ weights,[12*nside*nside*l_kernel,
+ 3*nside*nside])
+ else:
+ self.ww_CNN[nside]=self.backend.bk_SparseTensor(indices.reshape(12*nside*nside*l_kernel,2),
+ weights,[12*nside*nside*l_kernel,
+ 12*nside*nside])
+
+ # ---------------------------------------------−---------
+ def healpix_layer_transpose(self,im,ww):
+ nside=2*int(np.sqrt(im.shape[0]//12))
+ l_kernel=self.KERNELSZ*self.KERNELSZ
+
+ if im.shape[1]!=ww.shape[1]:
+ if not self.silent:
+ print('Weights channels should be equal to the input image channels')
+ return -1
+
+ if self.ww_CNN_Transpose[nside] is None:
+ self.init_CNN_index(nside,transpose=True)
+
+ tmp=self.backend.bk_sparse_dense_matmul(self.ww_CNN_Transpose[nside],im)
+
+ density=self.backend.bk_reshape(tmp,[12*nside*nside,l_kernel*im.shape[1]])
+
+ return self.backend.bk_matmul(density,self.backend.bk_reshape(ww,[l_kernel*im.shape[1],ww.shape[2]]))
+
+ # ---------------------------------------------−---------
+ # ---------------------------------------------−---------
+ def healpix_layer(self,im,ww):
+ nside=int(np.sqrt(im.shape[0]//12))
+ l_kernel=self.KERNELSZ*self.KERNELSZ
+
+ if im.shape[1]!=ww.shape[1]:
+ if not self.silent:
+ print('Weights channels should be equal to the input image channels')
+ return -1
+
+ if self.ww_CNN[nside] is None:
+ self.init_CNN_index(nside,transpose=False)
+
+ tmp=self.backend.bk_sparse_dense_matmul(self.ww_CNN[nside],im)
+ density=self.backend.bk_reshape(tmp,[12*nside*nside,l_kernel*im.shape[1]])
 
+ return self.backend.bk_matmul(density,self.backend.bk_reshape(ww,[l_kernel*im.shape[1],ww.shape[2]]))
+ # ---------------------------------------------−---------
+
  # ---------------------------------------------−---------
  def get_rank(self):
  return(self.rank)
@@ -332,7 +436,8 @@ class FoCUS:
  if self.use_2D:
  ishape=list(im.shape)
  if len(ishape)<axis+2:
- print('Use of 2D scat with data that has less than 2D')
+ if not self.silent:
+ print('Use of 2D scat with data that has less than 2D')
  exit(0)
 
  npix=im.shape[axis]
@@ -392,7 +497,8 @@ class FoCUS:
  if self.use_2D:
  ishape=list(im.shape)
  if len(ishape)<axis+2:
- print('Use of 2D scat with data that has less than 2D')
+ if not self.silent:
+ print('Use of 2D scat with data that has less than 2D')
  exit(0)
 
  if nouty is None:
@@ -434,7 +540,8 @@ class FoCUS:
  lout=int(np.sqrt(im.shape[axis]//12))
 
  if self.pix_interp_val[lout][nout] is None:
- print('compute lout nout',lout,nout)
+ if not self.silent:
+ print('compute lout nout',lout,nout)
  th,ph=hp.pix2ang(nout,np.arange(12*nout**2,dtype='int'),nest=True)
  p, w = hp.get_interp_weights(lout,th,ph,nest=True)
  del th
@@ -791,7 +898,8 @@ class FoCUS:
 
  for k in range(12*nside*nside):
  if k%(nside*nside)==0:
- print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
+ if not self.silent:
+ print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
  if nside>scale*2:
  lidx=hp.get_all_neighbours(nside//scale,th[k//(scale*scale)],ph[k//(scale*scale)],nest=True)
  lidx=np.concatenate([lidx,np.array([(k//(scale*scale))])],0)
@@ -843,7 +951,8 @@ class FoCUS:
  wav=w.flatten()
  wwav=wwav.flatten()
 
- print('Write FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
+ if not self.silent:
+ print('Write FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
  np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice)
  np.save('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),wav)
  np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice2)
@@ -857,7 +966,8 @@ class FoCUS:
  self.comp_idx_w25(nside)
  else:
  if self.rank==0:
- print('Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d'%(KERNELSZ,KERNELSZ))
+ if not self.silent:
+ print('Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d'%(KERNELSZ,KERNELSZ))
  exit(0)
 
  self.barrier()
@@ -882,6 +992,7 @@ class FoCUS:
  return tmp
 
  return wr,wi,ws,tmp
+
 
  # ---------------------------------------------−---------
  # Compute x [....,a,....] to [....,a*a,....]
@@ -1140,7 +1251,8 @@ class FoCUS:
 
  ishape=list(in_image.shape)
  if len(ishape)<axis+2:
- print('Use of 2D scat with data that has less than 2D')
+ if not self.silent:
+ print('Use of 2D scat with data that has less than 2D')
  exit(0)
 
  npix=ishape[axis]
@@ -1264,7 +1376,8 @@ class FoCUS:
 
  ishape=list(in_image.shape)
  if len(ishape)<axis+2:
- print('Use of 2D scat with data that has less than 2D')
+ if not self.silent:
+ print('Use of 2D scat with data that has less than 2D')
  exit(0)
 
  npix=ishape[axis]
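The main functional addition to FoCUS.py in this release is the pair of Healpix graph-convolution helpers, healpix_layer and healpix_layer_transpose, plus the silent flag that mutes the configuration banners. The sketch below is not from the package documentation; it only illustrates the call shapes implied by the diff (weights laid out as [KERNELSZ**2, input channels, output channels]) and assumes that scat_cov.funct, the FoCUS subclass used throughout foscat, forwards the new silent option to this constructor. On first use the index tables are precomputed and cached as .npy files under TEMPLATE_PATH.

import numpy as np
import foscat.scat_cov as sc

# silent=True relies on the new flag added in this version to mute the banner prints
op = sc.funct(KERNELSZ=3, silent=True)

nside, n_in, n_out = 16, 2, 4
# one Healpix map per input channel, NESTED ordering: [12*nside**2, n_in]
im = op.backend.bk_cast(np.random.randn(12 * nside**2, n_in).astype('float32'))
# kernel weights: [KERNELSZ**2, n_in, n_out]
ww = op.backend.bk_cast(np.random.randn(3 * 3, n_in, n_out).astype('float32'))

out = op.healpix_layer(im, ww)            # same resolution: [12*nside**2, n_out]
up  = op.healpix_layer_transpose(im, ww)  # upsampled: [12*(2*nside)**2, n_out]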
foscat-3.0.14/src/foscat/GCNN.py
@@ -0,0 +1,100 @@
+ import numpy as np
+ import pickle
+ import foscat.scat_cov as sc
+
+
+ class GCNN:
+
+ def __init__(self,
+ scat_operator=None,
+ nparam=1,
+ nscale=1,
+ chanlist=[],
+ in_nside=1,
+ nbatch=1,
+ SEED=1234,
+ filename=None):
+
+ if filename is not None:
+
+ outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
+
+ self.scat_operator=sc.funct(KERNELSZ=outlist[3],all_type=outlist[7])
+ self.KERNELSZ= self.scat_operator.KERNELSZ
+ self.all_type= self.scat_operator.all_type
+ self.npar=outlist[2]
+ self.nscale=outlist[5]
+ self.chanlist=outlist[0]
+ self.in_nside=outlist[4]
+ self.nbatch=outlist[1]
+
+ self.x=self.scat_operator.backend.bk_cast(outlist[6])
+ else:
+ self.nscale=nscale
+ self.nbatch=nbatch
+ self.npar=nparam
+ self.scat_operator=scat_operator
+
+ if len(chanlist)!=nscale+1:
+ print('len of chanlist (here %d) should of nscale+1 (here %d)'%(len(chanlist),nscale+1))
+ exit(0)
+
+ self.chanlist=chanlist
+ self.KERNELSZ= scat_operator.KERNELSZ
+ self.all_type= scat_operator.all_type
+ self.in_nside=in_nside
+
+ np.random.seed(SEED)
+ self.x=scat_operator.backend.bk_cast(np.random.randn(self.get_number_of_weights())/(self.KERNELSZ*self.KERNELSZ))
+
+ def save(self,filename):
+
+ outlist=[self.chanlist, \
+ self.nbatch, \
+ self.npar, \
+ self.KERNELSZ, \
+ self.in_nside, \
+ self.nscale, \
+ self.get_weights().numpy(), \
+ self.all_type]
+ myout=open("%s.pkl"%(filename),"wb")
+ pickle.dump(outlist,myout)
+ myout.close()
+
+ def get_number_of_weights(self):
+ totnchan=0
+ for i in range(self.nscale):
+ totnchan=totnchan+self.chanlist[i]*self.chanlist[i+1]
+ return self.npar*12*self.in_nside**2*self.chanlist[0] \
+ +totnchan*self.KERNELSZ*self.KERNELSZ+self.chanlist[self.nscale]
+
+ def set_weights(self,x):
+ self.x=x
+
+ def get_weights(self):
+ return self.x
+
+ def eval(self,param):
+
+ x=self.x
+
+ ww=self.scat_operator.backend.bk_reshape(x[0:self.npar*12*self.in_nside**2*self.chanlist[0]], \
+ [self.npar,12*self.in_nside**2*self.chanlist[0]])
+
+ im=self.scat_operator.backend.bk_matmul(self.scat_operator.backend.bk_reshape(param,[1,self.npar]),ww)
+ im=self.scat_operator.backend.bk_reshape(im,[12*self.in_nside**2,self.chanlist[0]])
+ im=self.scat_operator.backend.bk_relu(im)
+
+ nn=self.npar*12*self.chanlist[0]*self.in_nside**2
+ for k in range(self.nscale):
+ ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]],
+ [self.KERNELSZ*self.KERNELSZ,self.chanlist[k],self.chanlist[k+1]])
+ nn=nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]
+ im=self.scat_operator.healpix_layer_transpose(im,ww)
+ im=self.scat_operator.backend.bk_relu(im)
+
+ ww=self.scat_operator.backend.bk_reshape(x[nn:],[self.chanlist[self.nscale],1])
+ im=self.scat_operator.backend.bk_matmul(im,ww)
+
+ return self.scat_operator.backend.bk_reshape(im,[im.shape[0]])
+
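A hedged sketch of how the new GCNN class might be driven, based only on the code above and not a tested recipe: eval maps an nparam-long parameter vector through nscale transpose (upsampling) graph-CNN layers built on scat_operator.healpix_layer_transpose, so chanlist needs nscale+1 entries and the returned Healpix map has 12*(in_nside*2**nscale)**2 pixels. The operator construction and dtype handling below are assumptions.

import numpy as np
import foscat.scat_cov as sc
import foscat.GCNN as gcnn

op = sc.funct(KERNELSZ=3)

net = gcnn.GCNN(scat_operator=op,
                nparam=2,              # length of the input parameter vector
                nscale=3,              # number of upsampling layers
                chanlist=[8, 8, 4, 2], # nscale+1 channel counts
                in_nside=1)

param = op.backend.bk_cast(np.random.randn(2).astype('float32'))
out_map = net.eval(param)              # Healpix map at nside = in_nside * 2**nscale

net.save('my_gcnn')                    # writes my_gcnn.pkl
net_reloaded = gcnn.GCNN(filename='my_gcnn')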
{foscat-3.0.9 → foscat-3.0.14}/src/foscat/Synthesis.py
@@ -49,6 +49,7 @@ class Synthesis:
 
  self.loss_class=loss_list
  self.number_of_loss=len(loss_list)
+ self.__iteration__=1234
  self.nlog=0
  self.m_dw, self.v_dw = 0.0, 0.0
  self.beta1 = beta1
@@ -115,7 +116,7 @@
 
  self.nlog=self.nlog+1
  self.itt2=0
-
+
  if self.itt%self.EVAL_FREQUENCY==0 and self.mpi_rank==0:
  end = time.time()
  cur_loss='%10.3g ('%(self.ltot[self.ltot!=-1].mean())
@@ -131,6 +132,7 @@
  for k in range(info_gpu.shape[0]):
  mess=mess+'[GPU%d %.0f/%.0f MB %.0f%%]'%(k,info_gpu[k,0],info_gpu[k,1],info_gpu[k,2])
 
+
  print('%sItt %6d L=%s %.3fs %s'%(self.MESSAGE,self.itt,cur_loss,(end-self.start),mess))
  sys.stdout.flush()
  if self.KEEP_TRACK is not None:
@@ -140,13 +142,14 @@
  self.start = time.time()
 
  self.itt=self.itt+1
-
+
  # ---------------------------------------------−---------
  def calc_grad(self,in_x):
 
  g_tot=None
  l_tot=0.0
 
+
  if self.do_all_noise and self.totalsz>self.batchsz:
  nstep=self.totalsz//self.batchsz
  else:
@@ -158,6 +161,7 @@
 
  for istep in range(nstep):
 
+
  for k in range(self.number_of_loss):
  if self.loss_class[k].batch is None:
  l_batch=None
@@ -271,6 +275,7 @@
  self.SHOWGPU=SHOWGPU
  self.axis=axis
  self.in_x_nshape=in_x.shape[0]
+ self.seed=1234
 
  np.random.seed(self.mpi_rank*7+1234)
 
@@ -347,7 +352,7 @@
  start_x=x.copy()
 
  for iteration in range(NUM_STEP_BIAS):
-
+
  x,l,i=opt.fmin_l_bfgs_b(self.calc_grad,
  x.astype('float64'),
  callback=self.info_back,
@@ -357,17 +362,17 @@
 
  # update bias input data
  if iteration<NUM_STEP_BIAS-1:
- if self.mpi_rank==0:
- print('%s Hessian restart'%(self.MESSAGE))
+ #if self.mpi_rank==0:
+ # print('%s Hessian restart'%(self.MESSAGE))
 
  omap=self.xtractmap(x,axis)
 
  for k in range(self.number_of_loss):
  if self.loss_class[k].batch_update is not None:
  self.loss_class[k].batch_update(self.loss_class[k].batch_data,omap)
+ if self.loss_class[k].batch is not None:
  l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
  #x=start_x.copy()
-
 
  if self.mpi_rank==0 and SHOWGPU:
  self.stop_synthesis()
{foscat-3.0.9 → foscat-3.0.14}/src/foscat/backend.py
@@ -297,9 +297,14 @@ class foscat_backend:
  if self.BACKEND==self.TENSORFLOW:
  return self.backend.nn.conv1d(x,w, stride=[1,1,1], padding='SAME')
  if self.BACKEND==self.TORCH:
+ # Torch not yet done !!!
  return self.backend.nn.conv1d(x,w, stride=1, padding='SAME')
  if self.BACKEND==self.NUMPY:
- return self.backend.nn.conv1d(x,w, stride=1, padding='SAME')
+ res=np.zeros([x.shape[0],x.shape[1],w.shape[1]],dtype=x.dtype)
+ for k in range(w.shape[1]):
+ for l in range(w.shape[2]):
+ res[:,:,l]+=self.scipy.ndimage.convolve1d(x[:,:,k],w[:,k,l],axis=1)
+ return res
 
  def bk_flattenR(self,x):
  if self.BACKEND==self.TENSORFLOW or self.BACKEND==self.TORCH:
@@ -480,6 +485,41 @@
  return(self.backend.mean(data,axis))
  if self.BACKEND==self.NUMPY:
  return(np.mean(data,axis))
+
+ def bk_reduce_min(self,data,axis=None):
+
+ if axis is None:
+ if self.BACKEND==self.TENSORFLOW:
+ return(self.backend.reduce_min(data))
+ if self.BACKEND==self.TORCH:
+ return(self.backend.min(data))
+ if self.BACKEND==self.NUMPY:
+ return(np.min(data))
+ else:
+ if self.BACKEND==self.TENSORFLOW:
+ return(self.backend.reduce_min(data,axis=axis))
+ if self.BACKEND==self.TORCH:
+ return(self.backend.min(data,axis))
+ if self.BACKEND==self.NUMPY:
+ return(np.min(data,axis))
+
+ def bk_random_seed(self,value):
+
+ if self.BACKEND==self.TENSORFLOW:
+ return(self.backend.random.set_seed(value))
+ if self.BACKEND==self.TORCH:
+ return(self.backend.random.set_seed(value))
+ if self.BACKEND==self.NUMPY:
+ return(np.random.seed(value))
+
+ def bk_random_uniform(self,shape):
+
+ if self.BACKEND==self.TENSORFLOW:
+ return(self.backend.random.uniform(shape))
+ if self.BACKEND==self.TORCH:
+ return(self.backend.random.uniform(shape))
+ if self.BACKEND==self.NUMPY:
+ return(np.random.rand(shape))
 
  def bk_reduce_std(self,data,axis=None):
 
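Besides the new bk_reduce_min, bk_random_seed and bk_random_uniform wrappers that dispatch on the configured backend, this release rewrites the NumPy branch of bk_conv1d to accumulate per-channel 1-D convolutions with scipy.ndimage.convolve1d instead of deferring to a backend conv1d call. The standalone sketch below (my illustration, not the library's code) reproduces that accumulation pattern for data laid out as [batch, length, channels] with weights [kernel, channels_in, channels_out]; here the result array is allocated with the output-channel count.

import numpy as np
from scipy.ndimage import convolve1d

batch, length, c_in, c_out, ksz = 2, 32, 3, 4, 5
x = np.random.randn(batch, length, c_in).astype('float32')
w = np.random.randn(ksz, c_in, c_out).astype('float32')

res = np.zeros((batch, length, c_out), dtype=x.dtype)
for k in range(c_in):          # loop over input channels
    for l in range(c_out):     # loop over output channels
        # convolve channel k along the length axis with the 1-D kernel w[:, k, l]
        res[:, :, l] += convolve1d(x[:, :, k], w[:, k, l], axis=1)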
{foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov.py
@@ -1491,8 +1491,111 @@ class funct(FOC.FoCUS):
 
  return scat_cov(mS0, mP00, mC01, mC11, s1=mS1,c10=mC10,backend=self.backend), \
  scat_cov(sS0, sP00, sC01, sC11, s1=sS1,c10=sC10,backend=self.backend)
+
+ # compute local direction to make the statistical analysis more efficient
+ def stat_cfft(self,im,image2=None,upscale=False,smooth_scale=0):
+ tmp=im
+ if image2 is not None:
+ tmpi2=image2
+ if upscale:
+ l_nside=int(np.sqrt(tmp.shape[1]//12))
+ tmp=self.up_grade(tmp,l_nside*2,axis=1)
+ if image2 is not None:
+ tmpi2=self.up_grade(tmpi2,l_nside*2,axis=1)
+
+ l_nside=int(np.sqrt(tmp.shape[1]//12))
+ nscale=int(np.log(l_nside)/np.log(2))
+ cmat={}
+ cmat2={}
+ for k in range(nscale):
+ sim=self.backend.bk_abs(self.convol(tmp,axis=1))
+ if image2 is not None:
+ sim=self.backend.bk_real(self.backend.bk_L1(self.convol(tmp,axis=1)*self.backend.bk_conjugate(self.convol(tmpi2,axis=1))))
+ else:
+ sim=self.backend.bk_abs(self.convol(tmp,axis=1))
+
+ cc=self.backend.bk_reduce_mean(sim[:,:,0]-sim[:,:,2],0)
+ ss=self.backend.bk_reduce_mean(sim[:,:,1]-sim[:,:,3],0)
+ for m in range(smooth_scale):
+ if cc.shape[0]>12:
+ cc=self.ud_grade_2(self.smooth(cc))
+ ss=self.ud_grade_2(self.smooth(ss))
+ if cc.shape[0]!=tmp.shape[0]:
+ ll_nside=int(np.sqrt(tmp.shape[1]//12))
+ cc=self.up_grade(cc,ll_nside)
+ ss=self.up_grade(ss,ll_nside)
+ phase=np.fmod(np.arctan2(ss.numpy(),cc.numpy())+2*np.pi,2*np.pi)
+ iph=(4*phase/(2*np.pi)).astype('int')
+ alpha=(4*phase/(2*np.pi)-iph)
+ mat=np.zeros([sim.shape[1],4*4])
+ lidx=np.arange(sim.shape[1])
+ for l in range(4):
+ mat[lidx,4*((l+iph)%4)+l]=1.0-alpha
+ mat[lidx,4*((l+iph+1)%4)+l]=alpha
+
+ cmat[k]=self.backend.bk_cast(mat.astype('complex64'))
+
+ mat2=np.zeros([k+1,sim.shape[1],4,4*4])
+
+ for k2 in range(k+1):
+ tmp2=self.backend.bk_repeat(sim,4,axis=-1)
+ sim2=self.backend.bk_reduce_sum(self.backend.bk_reshape(mat.reshape(1,mat.shape[0],16)*tmp2,
+ [sim.shape[0],cmat[k].shape[0],4,4]),2)
+ sim2=self.backend.bk_abs(self.convol(sim2,axis=1))
+
+ cc=self.smooth(self.backend.bk_reduce_mean(sim2[:,:,0]-sim2[:,:,2],0))
+ ss=self.smooth(self.backend.bk_reduce_mean(sim2[:,:,1]-sim2[:,:,3],0))
+ for m in range(smooth_scale):
+ if cc.shape[0]>12:
+ cc=self.ud_grade_2(self.smooth(cc))
+ ss=self.ud_grade_2(self.smooth(ss))
+ if cc.shape[0]!=sim.shape[1]:
+ ll_nside=int(np.sqrt(sim.shape[1]//12))
+ cc=self.up_grade(cc,ll_nside)
+ ss=self.up_grade(ss,ll_nside)
+
+ phase=np.fmod(np.arctan2(ss.numpy(),cc.numpy())+2*np.pi,2*np.pi)
+ """
+ for k in range(4):
+ hp.mollview(np.fmod(phase+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,1+k))
+ plt.show()
+ exit(0)
+ """
+ iph=(4*phase/(2*np.pi)).astype('int')
+ alpha=(4*phase/(2*np.pi)-iph)
+ lidx=np.arange(sim.shape[1])
+ for m in range(4):
+ for l in range(4):
+ mat2[k2,lidx,m,4*((l+iph[:,m])%4)+l]=1.0-alpha[:,m]
+ mat2[k2,lidx,m,4*((l+iph[:,m]+1)%4)+l]=alpha[:,m]
+
+ cmat2[k]=self.backend.bk_cast(mat2.astype('complex64'))
+ """
+ tmp=self.backend.bk_repeat(sim[0],4,axis=1)
+ sim2=self.backend.bk_reduce_sum(self.backend.bk_reshape(mat*tmp,[12*nside**2,4,4]),1)
+
+ cc2=(sim2[:,0]-sim2[:,2])
+ ss2=(sim2[:,1]-sim2[:,3])
+ phase2=np.fmod(np.arctan2(ss2.numpy(),cc2.numpy())+2*np.pi,2*np.pi)
+
+ plt.figure()
+ hp.mollview(phase,cmap='jet',nest=True,hold=False,sub=(2,2,1))
+ hp.mollview(np.fmod(phase2+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,2))
+ plt.figure()
+ for k in range(4):
+ hp.mollview((sim[0,:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,1+k),min=-10,max=10)
+ hp.mollview((sim2[:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,5+k),min=-10,max=10)
+
+ plt.show()
+ """
+
+ if k<l_nside-1:
+ tmp=self.ud_grade_2(tmp,axis=1)
+ if image2 is not None:
+ tmpi2=self.ud_grade_2(tmpi2,axis=1)
+ return cmat,cmat2
 
- def eval(self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False):
+ def eval(self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False,cmat=None,cmat2=None):
  """
  Calculates the scattering correlations for a batch of images. Mean are done over pixels.
  mean of modulus:
@@ -1677,6 +1780,11 @@
  ####### S1 and P00
  ### Make the convolution I1 * Psi_j3
  conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
+
+ if cmat is not None:
+ tmp2=self.backend.bk_repeat(conv1,4,axis=-1)
+ conv1=self.backend.bk_reduce_sum(self.backend.bk_reshape(cmat[j3]*tmp2,[1,cmat[j3].shape[0],4,4]),2)
+
  ### Take the module M1 = |I1 * Psi_j3|
  M1_square = conv1*self.backend.bk_conjugate(conv1) # [Nbatch, Npix_j3, Norient3]
  M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
@@ -1749,6 +1857,9 @@
  else: # Cross
  ### Make the convolution I2 * Psi_j3
  conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
+ if cmat is not None:
+ tmp2=self.backend.bk_repeat(conv2,4,axis=-1)
+ conv2=self.backend.bk_reduce_sum(self.backend.bk_reshape(cmat[j3]*tmp2,[1,cmat[j3].shape[0],4,4]),2)
  ### Take the module M2 = |I2 * Psi_j3|
  M2_square = conv2*self.backend.bk_conjugate(conv2) # [Nbatch, Npix_j3, Norient3]
  M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
@@ -1852,19 +1963,19 @@
  ### C01_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
  if not cross:
  if calc_var:
- c01,vc01 = self._compute_C01(j2,
+ c01,vc01 = self._compute_C01(j2,j3,
  conv1,
  vmask,
  M1_dic,
  M1convPsi_dic,
- calc_var=True) # [Nbatch, Nmask, Norient3, Norient2]
+ calc_var=True,cmat2=cmat2) # [Nbatch, Nmask, Norient3, Norient2]
  else:
- c01 = self._compute_C01(j2,
+ c01 = self._compute_C01(j2,j3,
  conv1,
  vmask,
  M1_dic,
  M1convPsi_dic,
- return_data=return_data) # [Nbatch, Nmask, Norient3, Norient2]
+ return_data=return_data,cmat2=cmat2) # [Nbatch, Nmask, Norient3, Norient2]
 
  if return_data:
  if C01[j3] is None:
@@ -1892,31 +2003,31 @@
  ### C10_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
  else:
  if calc_var:
- c01,vc01 = self._compute_C01(j2,
+ c01,vc01 = self._compute_C01(j2,j3,
  conv1,
  vmask,
  M2_dic,
  M2convPsi_dic,
- calc_var=True)
- c10,vc10 = self._compute_C01(j2,
+ calc_var=True,cmat2=cmat2)
+ c10,vc10 = self._compute_C01(j2,j3,
  conv2,
  vmask,
  M1_dic,
  M1convPsi_dic,
- calc_var=True)
+ calc_var=True,cmat2=cmat2)
  else:
- c01 = self._compute_C01(j2,
+ c01 = self._compute_C01(j2,j3,
  conv1,
  vmask,
  M2_dic,
  M2convPsi_dic,
- return_data=return_data)
- c10 = self._compute_C01(j2,
+ return_data=return_data,cmat2=cmat2)
+ c10 = self._compute_C01(j2,j3,
  conv2,
  vmask,
  M1_dic,
  M1convPsi_dic,
- return_data=return_data)
+ return_data=return_data,cmat2=cmat2)
 
  if return_data:
  if C01[j3] is None:
@@ -2077,11 +2188,12 @@
  self.P2_dic = None
  return
 
- def _compute_C01(self, j2, conv,
+ def _compute_C01(self, j2, j3,conv,
  vmask, M_dic,
  MconvPsi_dic,
  calc_var=False,
- return_data=False):
+ return_data=False,
+ cmat2=None):
  """
  Compute the C01 coefficients (auto or cross)
  C01 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
@@ -2094,7 +2206,10 @@
  ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
  # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
  MconvPsi = self.convol(M_dic[j2], axis=1) # [Nbatch, Npix_j3, Norient3, Norient2]
-
+ if cmat2 is not None:
+ tmp2=self.backend.bk_repeat(MconvPsi,4,axis=-1)
+ MconvPsi=self.backend.bk_reduce_sum(self.backend.bk_reshape(cmat2[j3][j2]*tmp2,[1,cmat2[j3].shape[1],4,4,4]),3)
+
  # Store it so we can use it in C11 computation
  MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
 
@@ -2306,13 +2421,13 @@
  """
  @tf.function
  """
- def eval_comp_fast(self, image1, image2=None,mask=None,norm=None, Auto=True):
+ def eval_comp_fast(self, image1, image2=None,mask=None,norm=None, Auto=True,cmat=None,cmat2=None):
 
- res=self.eval(image1, image2=image2,mask=mask,Auto=Auto)
+ res=self.eval(image1, image2=image2,mask=mask,Auto=Auto,cmat=cmat,cmat2=cmat2)
  return res.S0,res.P00,res.S1,res.C01,res.C11,res.C10
 
- def eval_fast(self, image1, image2=None,mask=None,norm=None, Auto=True):
- s0,p0,s1,c01,c11,c10=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto)
+ def eval_fast(self, image1, image2=None,mask=None,norm=None, Auto=True,cmat=None,cmat2=None):
+ s0,p0,s1,c01,c11,c10=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,cmat=cmat,cmat2=cmat2)
  return scat_cov(s0, p0, c01, c11, s1=s1,c10=c10,backend=self.backend)
 
 
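The new stat_cfft method returns the two dictionaries of per-pixel orientation-interpolation matrices (cmat, cmat2) that eval, eval_fast and _compute_C01 now accept, so the local-direction correction can be precomputed once and reused across evaluations. A hedged usage sketch, assuming funct accepts NORIENT=4 (the value the 4-orientation logic in stat_cfft is written for) and that eval takes maps shaped [Nbatch, 12*nside**2]:

import numpy as np
import foscat.scat_cov as sc

op = sc.funct(KERNELSZ=3, NORIENT=4)

nside = 32
im = np.random.randn(1, 12 * nside**2).astype('float32')

# precompute the local-direction matrices once...
cmat, cmat2 = op.stat_cfft(im, smooth_scale=1)

# ...then reuse them when computing the scattering-covariance statistics
stats = op.eval(im, cmat=cmat, cmat2=cmat2)
ref   = op.eval(im)   # same call without the local-direction correction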
{foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: foscat
- Version: 3.0.9
+ Version: 3.0.14
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
  Home-page: https://github.com/jmdelouis/FOSCAT
  Author: Jean-Marc DELOUIS
{foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/SOURCES.txt
@@ -3,6 +3,7 @@ setup.cfg
  setup.py
  src/foscat/CircSpline.py
  src/foscat/FoCUS.py
+ src/foscat/GCNN.py
  src/foscat/GetGPUinfo.py
  src/foscat/Softmax.py
  src/foscat/Spline1D.py