foscat 3.0.10.tar.gz → 3.0.13.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {foscat-3.0.10 → foscat-3.0.13}/PKG-INFO +1 -7
  2. {foscat-3.0.10 → foscat-3.0.13}/setup.py +1 -1
  3. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/FoCUS.py +90 -0
  4. foscat-3.0.13/src/foscat/GCNN.py +100 -0
  5. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/Synthesis.py +11 -6
  6. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/backend.py +35 -0
  7. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/scat_cov.py +112 -16
  8. {foscat-3.0.10 → foscat-3.0.13}/src/foscat.egg-info/PKG-INFO +1 -7
  9. {foscat-3.0.10 → foscat-3.0.13}/src/foscat.egg-info/SOURCES.txt +1 -0
  10. {foscat-3.0.10 → foscat-3.0.13}/README.md +0 -0
  11. {foscat-3.0.10 → foscat-3.0.13}/setup.cfg +0 -0
  12. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/CircSpline.py +0 -0
  13. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/GetGPUinfo.py +0 -0
  14. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/Softmax.py +0 -0
  15. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/Spline1D.py +0 -0
  16. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/__init__.py +0 -0
  17. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/loss_backend_tens.py +0 -0
  18. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/loss_backend_torch.py +0 -0
  19. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/scat.py +0 -0
  20. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/scat1D.py +0 -0
  21. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/scat2D.py +0 -0
  22. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/scat_cov1D.py +0 -0
  23. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/scat_cov2D.py +0 -0
  24. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/scat_cov_map.py +0 -0
  25. {foscat-3.0.10 → foscat-3.0.13}/src/foscat/scat_cov_map2D.py +0 -0
  26. {foscat-3.0.10 → foscat-3.0.13}/src/foscat.egg-info/dependency_links.txt +0 -0
  27. {foscat-3.0.10 → foscat-3.0.13}/src/foscat.egg-info/requires.txt +0 -0
  28. {foscat-3.0.10 → foscat-3.0.13}/src/foscat.egg-info/top_level.txt +0 -0

{foscat-3.0.10 → foscat-3.0.13}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: foscat
-Version: 3.0.10
+Version: 3.0.13
 Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
 Home-page: https://github.com/jmdelouis/FOSCAT
 Author: Jean-Marc DELOUIS
@@ -9,12 +9,6 @@ Maintainer: Theo Foulquier
 Maintainer-email: theo.foulquier@ifremer.fr
 License: MIT
 Keywords: Scattering transform,Component separation,denoising
-Requires-Dist: imageio
-Requires-Dist: imagecodecs
-Requires-Dist: matplotlib
-Requires-Dist: numpy
-Requires-Dist: tensorflow
-Requires-Dist: healpy
 
 Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising.
 A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO.

{foscat-3.0.10 → foscat-3.0.13}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 
 setup(
     name='foscat',
-    version='3.0.10',
+    version='3.0.13',
     description='Generate synthetic Healpix or 2D data using Cross Scattering Transform' ,
     long_description='Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising. \n A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO. \n Complete doc can be found at https://foscat-documentation.readthedocs.io/en/latest/index.html. \n\n List of developers : J.-M. Delouis, T. Foulquier, L. Mousset, T. Odaka, F. Paul, E. Allys ' ,
     license='MIT',

{foscat-3.0.10 → foscat-3.0.13}/src/foscat/FoCUS.py
@@ -126,6 +126,8 @@ class FoCUS:
 
         self.ww_Real = {}
         self.ww_Imag = {}
+        self.ww_CNN_Transpose = {}
+        self.ww_CNN = {}
 
         wwc=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
         wws=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
@@ -271,6 +273,8 @@ class FoCUS:
             self.nest2R4[lout]=None
             self.inv_nest2R[lout]=None
             self.remove_border[lout]=None
+            self.ww_CNN_Transpose[lout]=None
+            self.ww_CNN[lout]=None
 
         self.loss={}
 
@@ -296,7 +300,92 @@ class FoCUS:
         res=x
         res[idx]=y[idx]
         return(res)
+
+    # ---------------------------------------------−---------
+    # make the CNN working : index reporjection of the kernel on healpix
+
+    def init_CNN_index(self,nside,transpose=False):
+        l_kernel=int(self.KERNELSZ*self.KERNELSZ)
+        weights=self.backend.bk_cast(np.ones([12*nside*nside*l_kernel],dtype='float'))
+        try:
+            if transpose:
+                indices=np.load('%s/FOSCAT_%s_W%d_%d_%d_CNN_Transpose.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
+            else:
+                indices=np.load('%s/FOSCAT_%s_W%d_%d_%d_CNN.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
+        except:
+            to,po=hp.pix2ang(nside,np.arange(12*nside*nside),nest=True)
+            x,y,z=hp.pix2vec(nside,np.arange(12*nside*nside),nest=True)
+
+            idx=np.argsort((x-1.0)**2+y**2+z**2)[0:l_kernel]
+            tc,pc=hp.pix2ang(nside,idx,nest=True)
+
+            indices=np.zeros([12*nside*nside,l_kernel,2],dtype='int')
+            for k in range(12*nside*nside):
+                if k%(nside*nside)==0:
+                    print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
+
+                rot=[po[k]/np.pi*180.0,90+(-to[k])/np.pi*180.0]
+                r=hp.Rotator(rot=rot).get_inverse()
+                # get the coordinate
+                ty,tx=r(tc,pc)
+
+                indices[k,:,0]=k*l_kernel+np.arange(l_kernel).astype('int')
+                indices[k,:,1]=hp.ang2pix(nside,ty,tx,nest=True)
+            if transpose:
+                indices[:,:,1]=indices[:,:,1]//4
+                np.save('%s/FOSCAT_%s_W%d_%d_%d_CNN_Transpose.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside),indices)
+                print('Write %s/FOSCAT_%s_W%d_%d_%d_CNN_Transpose.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
+            else:
+                np.save('%s/FOSCAT_%s_W%d_%d_%d_CNNnpy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside),indices)
+                print('Write %s/FOSCAT_%s_W%d_%d_%d_CNN.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
+
+
+        if transpose:
+            self.ww_CNN_Transpose[nside]=self.backend.bk_SparseTensor(indices.reshape(12*nside*nside*l_kernel,2),
+                                                                      weights,[12*nside*nside*l_kernel,
+                                                                               3*nside*nside])
+        else:
+            self.ww_CNN[nside]=self.backend.bk_SparseTensor(indices.reshape(12*nside*nside*l_kernel,2),
+                                                            weights,[12*nside*nside*l_kernel,
+                                                                     12*nside*nside])
+
+    # ---------------------------------------------−---------
+    def healpix_layer_transpose(self,im,ww):
+        nside=2*int(np.sqrt(im.shape[0]//12))
+        l_kernel=self.KERNELSZ*self.KERNELSZ
+
+        if im.shape[1]!=ww.shape[1]:
+            print('Weights channels should be equal to the input image channels')
+            return -1
+
+        if self.ww_CNN_Transpose[nside] is None:
+            self.init_CNN_index(nside,transpose=True)
+
+        tmp=self.backend.bk_sparse_dense_matmul(self.ww_CNN_Transpose[nside],im)
+
+        density=self.backend.bk_reshape(tmp,[12*nside*nside,l_kernel*im.shape[1]])
+
+        return self.backend.bk_matmul(density,self.backend.bk_reshape(ww,[l_kernel*im.shape[1],ww.shape[2]]))
+
+    # ---------------------------------------------−---------
+    # ---------------------------------------------−---------
+    def healpix_layer(self,im,ww):
+        nside=int(np.sqrt(im.shape[0]//12))
+        l_kernel=self.KERNELSZ*self.KERNELSZ
+
+        if im.shape[1]!=ww.shape[1]:
+            print('Weights channels should be equal to the input image channels')
+            return -1
+
+        if self.ww_CNN[nside] is None:
+            self.init_CNN_index(nside,transpose=False)
+
+        tmp=self.backend.bk_sparse_dense_matmul(self.ww_CNN[nside],im)
+        density=self.backend.bk_reshape(tmp,[12*nside*nside,l_kernel*im.shape[1]])
 
+        return self.backend.bk_matmul(density,self.backend.bk_reshape(ww,[l_kernel*im.shape[1],ww.shape[2]]))
+    # ---------------------------------------------−---------
+
 
     # ---------------------------------------------−---------
     def get_rank(self):
         return(self.rank)
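
The two layers above turn a HEALPix map of shape [12*nside**2, Nchan_in] into a map with Nchan_out channels, healpix_layer at the same nside and healpix_layer_transpose at twice the input nside, using a sparse reprojection of the KERNELSZ x KERNELSZ kernel that is cached on disk the first time a resolution is used. A minimal usage sketch follows; the sc.funct(KERNELSZ=3) call and the chosen nside are illustrative assumptions (the diff only confirms the KERNELSZ and all_type keywords), and nside should be one of the resolutions for which the operator pre-allocated its ww_CNN caches.

    import numpy as np
    import foscat.scat_cov as sc

    fc = sc.funct(KERNELSZ=3)   # assumed constructor call; exposes the FoCUS methods above
    nside, chan_in, chan_out = 16, 8, 16

    im = fc.backend.bk_cast(np.random.randn(12 * nside**2, chan_in).astype('float32'))
    ww = fc.backend.bk_cast(np.random.randn(fc.KERNELSZ**2, chan_in, chan_out).astype('float32'))

    out = fc.healpix_layer(im, ww)   # [12*nside**2, chan_out]; builds ww_CNN[nside] on first call
    ww_up = fc.backend.bk_cast(np.random.randn(fc.KERNELSZ**2, chan_out, chan_out).astype('float32'))
    up = fc.healpix_layer_transpose(out, ww_up)   # [12*(2*nside)**2, chan_out]: one upsampling step
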
@@ -882,6 +971,7 @@ class FoCUS:
             return tmp
 
         return wr,wi,ws,tmp
+
 
     # ---------------------------------------------−---------
     # Compute x [....,a,....] to [....,a*a,....]

foscat-3.0.13/src/foscat/GCNN.py (new file)
@@ -0,0 +1,100 @@
+import numpy as np
+import pickle
+import foscat.scat_cov as sc
+
+
+class GCNN:
+
+    def __init__(self,
+                 scat_operator=None,
+                 nparam=1,
+                 nscale=1,
+                 chanlist=[],
+                 in_nside=1,
+                 nbatch=1,
+                 SEED=1234,
+                 filename=None):
+
+        if filename is not None:
+
+            outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
+
+            self.scat_operator=sc.funct(KERNELSZ=outlist[3],all_type=outlist[7])
+            self.KERNELSZ= self.scat_operator.KERNELSZ
+            self.all_type= self.scat_operator.all_type
+            self.npar=outlist[2]
+            self.nscale=outlist[5]
+            self.chanlist=outlist[0]
+            self.in_nside=outlist[4]
+            self.nbatch=outlist[1]
+
+            self.x=self.scat_operator.backend.bk_cast(outlist[6])
+        else:
+            self.nscale=nscale
+            self.nbatch=nbatch
+            self.npar=nparam
+            self.scat_operator=scat_operator
+
+            if len(chanlist)!=nscale+1:
+                print('len of chanlist (here %d) should of nscale+1 (here %d)'%(len(chanlist),nscale+1))
+                exit(0)
+
+            self.chanlist=chanlist
+            self.KERNELSZ= scat_operator.KERNELSZ
+            self.all_type= scat_operator.all_type
+            self.in_nside=in_nside
+
+            np.random.seed(SEED)
+            self.x=scat_operator.backend.bk_cast(np.random.randn(self.get_number_of_weights())/(self.KERNELSZ*self.KERNELSZ))
+
+    def save(self,filename):
+
+        outlist=[self.chanlist, \
+                 self.nbatch, \
+                 self.npar, \
+                 self.KERNELSZ, \
+                 self.in_nside, \
+                 self.nscale, \
+                 self.get_weights().numpy(), \
+                 self.all_type]
+        myout=open("%s.pkl"%(filename),"wb")
+        pickle.dump(outlist,myout)
+        myout.close()
+
+    def get_number_of_weights(self):
+        totnchan=0
+        for i in range(self.nscale):
+            totnchan=totnchan+self.chanlist[i]*self.chanlist[i+1]
+        return self.npar*12*self.in_nside**2*self.chanlist[0] \
+            +totnchan*self.KERNELSZ*self.KERNELSZ+self.chanlist[self.nscale]
+
+    def set_weights(self,x):
+        self.x=x
+
+    def get_weights(self):
+        return self.x
+
+    def eval(self,param):
+
+        x=self.x
+
+        ww=self.scat_operator.backend.bk_reshape(x[0:self.npar*12*self.in_nside**2*self.chanlist[0]], \
+                                                 [self.npar,12*self.in_nside**2*self.chanlist[0]])
+
+        im=self.scat_operator.backend.bk_matmul(self.scat_operator.backend.bk_reshape(param,[1,self.npar]),ww)
+        im=self.scat_operator.backend.bk_reshape(im,[12*self.in_nside**2,self.chanlist[0]])
+        im=self.scat_operator.backend.bk_relu(im)
+
+        nn=self.npar*12*self.chanlist[0]*self.in_nside**2
+        for k in range(self.nscale):
+            ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]],
+                                                     [self.KERNELSZ*self.KERNELSZ,self.chanlist[k],self.chanlist[k+1]])
+            nn=nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]
+            im=self.scat_operator.healpix_layer_transpose(im,ww)
+            im=self.scat_operator.backend.bk_relu(im)
+
+        ww=self.scat_operator.backend.bk_reshape(x[nn:],[self.chanlist[self.nscale],1])
+        im=self.scat_operator.backend.bk_matmul(im,ww)
+
+        return self.scat_operator.backend.bk_reshape(im,[im.shape[0]])
+
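
A hedged usage sketch for the new GCNN class; everything below mirrors the code above, while the sc.funct(KERNELSZ=3) construction and the concrete nparam/nscale/chanlist/in_nside values are illustrative assumptions, and the intermediate resolutions must be ones the scattering operator supports.

    import numpy as np
    import foscat.scat_cov as sc
    from foscat.GCNN import GCNN

    op = sc.funct(KERNELSZ=3)            # assumed constructor call
    net = GCNN(scat_operator=op,
               nparam=2,                 # length of the input parameter vector
               nscale=2,                 # number of healpix_layer_transpose stages
               chanlist=[8, 8, 4],       # must contain nscale+1 entries
               in_nside=1)               # resolution of the first dense layer

    param = op.backend.bk_cast(np.array([1.0, 0.5], dtype='float32'))
    out_map = net.eval(param)            # flat map with 12*(in_nside*2**nscale)**2 pixels

    net.save('my_gcnn')                  # writes my_gcnn.pkl
    net2 = GCNN(filename='my_gcnn')      # rebuilds the operator and reloads the weights
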

{foscat-3.0.10 → foscat-3.0.13}/src/foscat/Synthesis.py
@@ -49,6 +49,7 @@ class Synthesis:
 
         self.loss_class=loss_list
         self.number_of_loss=len(loss_list)
+        self.__iteration__=1234
         self.nlog=0
         self.m_dw, self.v_dw = 0.0, 0.0
         self.beta1 = beta1
@@ -115,7 +116,7 @@ class Synthesis:
 
         self.nlog=self.nlog+1
         self.itt2=0
-
+
         if self.itt%self.EVAL_FREQUENCY==0 and self.mpi_rank==0:
             end = time.time()
             cur_loss='%10.3g ('%(self.ltot[self.ltot!=-1].mean())
@@ -131,6 +132,7 @@ class Synthesis:
                 for k in range(info_gpu.shape[0]):
                     mess=mess+'[GPU%d %.0f/%.0f MB %.0f%%]'%(k,info_gpu[k,0],info_gpu[k,1],info_gpu[k,2])
 
+
             print('%sItt %6d L=%s %.3fs %s'%(self.MESSAGE,self.itt,cur_loss,(end-self.start),mess))
             sys.stdout.flush()
             if self.KEEP_TRACK is not None:
@@ -140,13 +142,14 @@
             self.start = time.time()
 
         self.itt=self.itt+1
-
+
     # ---------------------------------------------−---------
    def calc_grad(self,in_x):
 
        g_tot=None
        l_tot=0.0
 
+
        if self.do_all_noise and self.totalsz>self.batchsz:
            nstep=self.totalsz//self.batchsz
        else:
@@ -158,6 +161,7 @@
 
        for istep in range(nstep):
 
+
            for k in range(self.number_of_loss):
                if self.loss_class[k].batch is None:
                    l_batch=None
@@ -271,6 +275,7 @@
        self.SHOWGPU=SHOWGPU
        self.axis=axis
        self.in_x_nshape=in_x.shape[0]
+       self.seed=1234
 
        np.random.seed(self.mpi_rank*7+1234)
 
@@ -347,7 +352,7 @@
        start_x=x.copy()
 
        for iteration in range(NUM_STEP_BIAS):
-
+
            x,l,i=opt.fmin_l_bfgs_b(self.calc_grad,
                                    x.astype('float64'),
                                    callback=self.info_back,
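
For reference, the call above relies on scipy.optimize.fmin_l_bfgs_b being given a function that returns (loss, gradient) as float64, which is what calc_grad provides, plus a per-iteration callback. A tiny self-contained illustration (not foscat code, a quadratic stand-in):

    import numpy as np
    import scipy.optimize as opt

    def loss_and_grad(x):
        # stand-in for Synthesis.calc_grad: returns (loss, gradient)
        return np.sum((x - 3.0) ** 2), 2.0 * (x - 3.0)

    def info_back(xk):
        # stand-in for Synthesis.info_back, called once per L-BFGS iteration
        print('current loss %.3g' % np.sum((xk - 3.0) ** 2))

    x0 = np.zeros(5)
    x, l, info = opt.fmin_l_bfgs_b(loss_and_grad, x0.astype('float64'), callback=info_back)
    print(x, l, info['warnflag'])
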
@@ -357,17 +362,17 @@
 
            # update bias input data
            if iteration<NUM_STEP_BIAS-1:
-                if self.mpi_rank==0:
-                    print('%s Hessian restart'%(self.MESSAGE))
+                #if self.mpi_rank==0:
+                #    print('%s Hessian restart'%(self.MESSAGE))
 
                omap=self.xtractmap(x,axis)
 
                for k in range(self.number_of_loss):
                    if self.loss_class[k].batch_update is not None:
                        self.loss_class[k].batch_update(self.loss_class[k].batch_data,omap)
+                    if self.loss_class[k].batch is not None:
                        l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
                #x=start_x.copy()
-
 
        if self.mpi_rank==0 and SHOWGPU:
            self.stop_synthesis()
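
The added "batch is not None" guard matters because a loss entry can define batch_update without defining batch; the previous code would then try to call None. A minimal plain-Python sketch (only the attribute names come from the hunk above, the DummyLoss class is hypothetical):

    class DummyLoss:
        # hypothetical stand-in for a foscat loss entry
        def __init__(self, batch=None, batch_update=None, batch_data=None):
            self.batch = batch                # optional callable or None
            self.batch_update = batch_update  # optional callable or None
            self.batch_data = batch_data

    losses = [DummyLoss(batch_update=lambda data, omap: None)]  # batch_update only, no batch()
    omap = None
    for k in range(len(losses)):
        if losses[k].batch_update is not None:
            losses[k].batch_update(losses[k].batch_data, omap)
        if losses[k].batch is not None:  # the new guard: skip when no batch() is provided
            l_batch = losses[k].batch(losses[k].batch_data, 0, init=True)
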

{foscat-3.0.10 → foscat-3.0.13}/src/foscat/backend.py
@@ -480,6 +480,41 @@ class foscat_backend:
            return(self.backend.mean(data,axis))
        if self.BACKEND==self.NUMPY:
            return(np.mean(data,axis))
+
+    def bk_reduce_min(self,data,axis=None):
+
+        if axis is None:
+            if self.BACKEND==self.TENSORFLOW:
+                return(self.backend.reduce_min(data))
+            if self.BACKEND==self.TORCH:
+                return(self.backend.min(data))
+            if self.BACKEND==self.NUMPY:
+                return(np.min(data))
+        else:
+            if self.BACKEND==self.TENSORFLOW:
+                return(self.backend.reduce_min(data,axis=axis))
+            if self.BACKEND==self.TORCH:
+                return(self.backend.min(data,axis))
+            if self.BACKEND==self.NUMPY:
+                return(np.min(data,axis))
+
+    def bk_random_seed(self,value):
+
+        if self.BACKEND==self.TENSORFLOW:
+            return(self.backend.random.set_seed(value))
+        if self.BACKEND==self.TORCH:
+            return(self.backend.random.set_seed(value))
+        if self.BACKEND==self.NUMPY:
+            return(np.random.seed(value))
+
+    def bk_random_uniform(self,shape):
+
+        if self.BACKEND==self.TENSORFLOW:
+            return(self.backend.random.uniform(shape))
+        if self.BACKEND==self.TORCH:
+            return(self.backend.random.uniform(shape))
+        if self.BACKEND==self.NUMPY:
+            return(np.random.rand(shape))
 
    def bk_reduce_std(self,data,axis=None):
 
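
A hedged sketch of the three new wrappers; the foscat_backend constructor is not shown in this diff, so the instance is reached through a scattering operator. The sc.funct(KERNELSZ=3) call is an assumption, and the default TensorFlow backend is assumed for the random helpers.

    import numpy as np
    import foscat.scat_cov as sc

    op = sc.funct(KERNELSZ=3)     # assumed constructor call; gives access to op.backend
    bk = op.backend               # the foscat_backend instance patched in this hunk

    bk.bk_random_seed(1234)                        # dispatches to the active backend's seeding
    u = bk.bk_random_uniform([12 * 16 ** 2])       # uniform tensor in the active backend
    data = bk.bk_cast(np.random.randn(4, 10).astype('float32'))
    print(bk.bk_reduce_min(data))                  # global minimum
    print(bk.bk_reduce_min(data, axis=1))          # minimum along axis 1
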

{foscat-3.0.10 → foscat-3.0.13}/src/foscat/scat_cov.py
@@ -1491,8 +1491,100 @@ class funct(FOC.FoCUS):
 
         return scat_cov(mS0, mP00, mC01, mC11, s1=mS1,c10=mC10,backend=self.backend), \
                scat_cov(sS0, sP00, sC01, sC11, s1=sS1,c10=sC10,backend=self.backend)
+
+    # compute local direction to make the statistical analysis more efficient
+    def stat_cfft(self,im,upscale=False,smooth_scale=0):
+        tmp=im
+        if upscale:
+            l_nside=int(np.sqrt(tmp.shape[1]//12))
+            tmp=self.up_grade(tmp,l_nside*2,axis=1)
+
+        l_nside=int(np.sqrt(tmp.shape[1]//12))
+        nscale=int(np.log(l_nside)/np.log(2))
+        cmat={}
+        cmat2={}
+        for k in range(nscale):
+            sim=self.backend.bk_abs(self.convol(tmp,axis=1))
+            cc=self.backend.bk_reduce_mean(sim[:,:,0]-sim[:,:,2],0)
+            ss=self.backend.bk_reduce_mean(sim[:,:,1]-sim[:,:,3],0)
+            for m in range(smooth_scale):
+                if cc.shape[0]>12:
+                    cc=self.ud_grade_2(self.smooth(cc))
+                    ss=self.ud_grade_2(self.smooth(ss))
+            if cc.shape[0]!=tmp.shape[0]:
+                ll_nside=int(np.sqrt(tmp.shape[1]//12))
+                cc=self.up_grade(cc,ll_nside)
+                ss=self.up_grade(ss,ll_nside)
+            phase=np.fmod(np.arctan2(ss.numpy(),cc.numpy())+2*np.pi,2*np.pi)
+            iph=(4*phase/(2*np.pi)).astype('int')
+            alpha=(4*phase/(2*np.pi)-iph)
+            mat=np.zeros([sim.shape[1],4*4])
+            lidx=np.arange(sim.shape[1])
+            for l in range(4):
+                mat[lidx,4*((l+iph)%4)+l]=1.0-alpha
+                mat[lidx,4*((l+iph+1)%4)+l]=alpha
+
+            cmat[k]=self.backend.bk_cast(mat.astype('complex64'))
+
+            mat2=np.zeros([k+1,sim.shape[1],4,4*4])
+
+            for k2 in range(k+1):
+                tmp2=self.backend.bk_repeat(sim,4,axis=-1)
+                sim2=self.backend.bk_reduce_sum(self.backend.bk_reshape(mat.reshape(1,mat.shape[0],16)*tmp2,
+                                                                        [sim.shape[0],cmat[k].shape[0],4,4]),2)
+                sim2=self.backend.bk_abs(self.convol(sim2,axis=1))
+
+                cc=self.smooth(self.backend.bk_reduce_mean(sim2[:,:,0]-sim2[:,:,2],0))
+                ss=self.smooth(self.backend.bk_reduce_mean(sim2[:,:,1]-sim2[:,:,3],0))
+                for m in range(smooth_scale):
+                    if cc.shape[0]>12:
+                        cc=self.ud_grade_2(self.smooth(cc))
+                        ss=self.ud_grade_2(self.smooth(ss))
+                if cc.shape[0]!=sim.shape[1]:
+                    ll_nside=int(np.sqrt(sim.shape[1]//12))
+                    cc=self.up_grade(cc,ll_nside)
+                    ss=self.up_grade(ss,ll_nside)
+
+                phase=np.fmod(np.arctan2(ss.numpy(),cc.numpy())+2*np.pi,2*np.pi)
+                """
+                for k in range(4):
+                    hp.mollview(np.fmod(phase+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,1+k))
+                plt.show()
+                exit(0)
+                """
+                iph=(4*phase/(2*np.pi)).astype('int')
+                alpha=(4*phase/(2*np.pi)-iph)
+                lidx=np.arange(sim.shape[1])
+                for m in range(4):
+                    for l in range(4):
+                        mat2[k2,lidx,m,4*((l+iph[:,m])%4)+l]=1.0-alpha[:,m]
+                        mat2[k2,lidx,m,4*((l+iph[:,m]+1)%4)+l]=alpha[:,m]
+
+            cmat2[k]=self.backend.bk_cast(mat2.astype('complex64'))
+            """
+            tmp=self.backend.bk_repeat(sim[0],4,axis=1)
+            sim2=self.backend.bk_reduce_sum(self.backend.bk_reshape(mat*tmp,[12*nside**2,4,4]),1)
+
+            cc2=(sim2[:,0]-sim2[:,2])
+            ss2=(sim2[:,1]-sim2[:,3])
+            phase2=np.fmod(np.arctan2(ss2.numpy(),cc2.numpy())+2*np.pi,2*np.pi)
+
+            plt.figure()
+            hp.mollview(phase,cmap='jet',nest=True,hold=False,sub=(2,2,1))
+            hp.mollview(np.fmod(phase2+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,2))
+            plt.figure()
+            for k in range(4):
+                hp.mollview((sim[0,:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,1+k),min=-10,max=10)
+                hp.mollview((sim2[:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,5+k),min=-10,max=10)
+
+            plt.show()
+            """
+
+            if k<l_nside-1:
+                tmp=self.ud_grade_2(tmp,axis=1)
+        return cmat,cmat2
 
-    def eval(self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False,cmat=None):
+    def eval(self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False,cmat=None,cmat2=None):
         """
         Calculates the scattering correlations for a batch of images. Mean are done over pixels.
         mean of modulus:
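
A hedged sketch of how stat_cfft and the widened eval signature appear intended to work together; the sc.funct(KERNELSZ=3) call is assumed, the input is cast through the operator's backend, and the shapes follow stat_cfft (a batch of HEALPix maps, [Nbatch, 12*nside**2]).

    import numpy as np
    import foscat.scat_cov as sc

    op = sc.funct(KERNELSZ=3)                        # assumed constructor call
    nside = 16
    sims = op.backend.bk_cast(
        np.random.randn(8, 12 * nside ** 2).astype('float32'))     # batch of maps

    cmat, cmat2 = op.stat_cfft(sims, upscale=False, smooth_scale=1) # per-scale direction matrices
    stats = op.eval(sims[0], cmat=cmat, cmat2=cmat2)                # covariances steered by cmat/cmat2
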
@@ -1860,19 +1952,19 @@ class funct(FOC.FoCUS):
             ### C01_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
             if not cross:
                 if calc_var:
-                    c01,vc01 = self._compute_C01(j2,
+                    c01,vc01 = self._compute_C01(j2,j3,
                                                  conv1,
                                                  vmask,
                                                  M1_dic,
                                                  M1convPsi_dic,
-                                                 calc_var=True) # [Nbatch, Nmask, Norient3, Norient2]
+                                                 calc_var=True,cmat2=cmat2) # [Nbatch, Nmask, Norient3, Norient2]
                 else:
-                    c01 = self._compute_C01(j2,
+                    c01 = self._compute_C01(j2,j3,
                                             conv1,
                                             vmask,
                                             M1_dic,
                                             M1convPsi_dic,
-                                            return_data=return_data) # [Nbatch, Nmask, Norient3, Norient2]
+                                            return_data=return_data,cmat2=cmat2) # [Nbatch, Nmask, Norient3, Norient2]
 
                 if return_data:
                     if C01[j3] is None:
@@ -1900,31 +1992,31 @@ class funct(FOC.FoCUS):
             ### C10_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
             else:
                 if calc_var:
-                    c01,vc01 = self._compute_C01(j2,
+                    c01,vc01 = self._compute_C01(j2,j3,
                                                  conv1,
                                                  vmask,
                                                  M2_dic,
                                                  M2convPsi_dic,
-                                                 calc_var=True)
-                    c10,vc10 = self._compute_C01(j2,
+                                                 calc_var=True,cmat2=cmat2)
+                    c10,vc10 = self._compute_C01(j2,j3,
                                                  conv2,
                                                  vmask,
                                                  M1_dic,
                                                  M1convPsi_dic,
-                                                 calc_var=True)
+                                                 calc_var=True,cmat2=cmat2)
                 else:
-                    c01 = self._compute_C01(j2,
+                    c01 = self._compute_C01(j2,j3,
                                             conv1,
                                             vmask,
                                             M2_dic,
                                             M2convPsi_dic,
-                                            return_data=return_data)
-                    c10 = self._compute_C01(j2,
+                                            return_data=return_data,cmat2=cmat2)
+                    c10 = self._compute_C01(j2,j3,
                                             conv2,
                                             vmask,
                                             M1_dic,
                                             M1convPsi_dic,
-                                            return_data=return_data)
+                                            return_data=return_data,cmat2=cmat2)
 
                 if return_data:
                     if C01[j3] is None:
@@ -2085,11 +2177,12 @@ class funct(FOC.FoCUS):
         self.P2_dic = None
         return
 
-    def _compute_C01(self, j2, conv,
+    def _compute_C01(self, j2, j3,conv,
                      vmask, M_dic,
                      MconvPsi_dic,
                      calc_var=False,
-                     return_data=False):
+                     return_data=False,
+                     cmat2=None):
         """
         Compute the C01 coefficients (auto or cross)
         C01 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
@@ -2102,7 +2195,10 @@ class funct(FOC.FoCUS):
         ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
         # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
         MconvPsi = self.convol(M_dic[j2], axis=1) # [Nbatch, Npix_j3, Norient3, Norient2]
-
+        if cmat2 is not None:
+            tmp2=self.backend.bk_repeat(MconvPsi,4,axis=-1)
+            MconvPsi=self.backend.bk_reduce_sum(self.backend.bk_reshape(cmat2[j3][j2]*tmp2,[1,cmat2[j3].shape[1],4,4,4]),3)
+
         # Store it so we can use it in C11 computation
         MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
 
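
The cmat2 branch above is a per-pixel remix of the four oriented coefficients by a flattened 4x4 interpolation matrix. The NumPy lines below reproduce the repeat/reshape/reduce_sum pattern for a single real-valued orientation block and check it against the equivalent einsum; the array names are illustrative, and the complex dtype and extra orientation axis of the real code are dropped for clarity.

    import numpy as np

    npix = 12 * 16 ** 2
    coeff = np.random.randn(npix, 4)   # 4 oriented coefficients per pixel (stand-in for MconvPsi)
    mat = np.random.rand(npix, 16)     # flattened per-pixel 4x4 weights (stand-in for cmat2[j3][j2])

    tmp = np.repeat(coeff, 4, axis=-1)                         # [npix, 16]
    remixed = (mat * tmp).reshape(npix, 4, 4).sum(axis=1)      # [npix, 4]
    assert np.allclose(remixed, np.einsum('pi,pij->pj', coeff, mat.reshape(npix, 4, 4)))
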

{foscat-3.0.10 → foscat-3.0.13}/src/foscat.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: foscat
-Version: 3.0.10
+Version: 3.0.13
 Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
 Home-page: https://github.com/jmdelouis/FOSCAT
 Author: Jean-Marc DELOUIS
@@ -9,12 +9,6 @@ Maintainer: Theo Foulquier
 Maintainer-email: theo.foulquier@ifremer.fr
 License: MIT
 Keywords: Scattering transform,Component separation,denoising
-Requires-Dist: imageio
-Requires-Dist: imagecodecs
-Requires-Dist: matplotlib
-Requires-Dist: numpy
-Requires-Dist: tensorflow
-Requires-Dist: healpy
 
 Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising.
 A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO.

{foscat-3.0.10 → foscat-3.0.13}/src/foscat.egg-info/SOURCES.txt
@@ -3,6 +3,7 @@ setup.cfg
 setup.py
 src/foscat/CircSpline.py
 src/foscat/FoCUS.py
+src/foscat/GCNN.py
 src/foscat/GetGPUinfo.py
 src/foscat/Softmax.py
 src/foscat/Spline1D.py