foscat 3.0.19__tar.gz → 3.0.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {foscat-3.0.19 → foscat-3.0.21}/PKG-INFO +1 -7
  2. {foscat-3.0.19 → foscat-3.0.21}/setup.py +1 -1
  3. foscat-3.0.21/src/foscat/CNN.py +110 -0
  4. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/FoCUS.py +1 -1
  5. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/Synthesis.py +1 -2
  6. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/backend.py +0 -2
  7. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/scat.py +41 -30
  8. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/scat1D.py +59 -12
  9. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/scat_cov.py +68 -41
  10. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/scat_cov1D.py +40 -27
  11. {foscat-3.0.19 → foscat-3.0.21}/src/foscat.egg-info/PKG-INFO +1 -7
  12. {foscat-3.0.19 → foscat-3.0.21}/src/foscat.egg-info/SOURCES.txt +1 -0
  13. {foscat-3.0.19 → foscat-3.0.21}/README.md +0 -0
  14. {foscat-3.0.19 → foscat-3.0.21}/setup.cfg +0 -0
  15. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/CircSpline.py +0 -0
  16. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/GCNN.py +0 -0
  17. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/GetGPUinfo.py +0 -0
  18. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/Softmax.py +0 -0
  19. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/Spline1D.py +0 -0
  20. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/__init__.py +0 -0
  21. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/backend_tens.py +0 -0
  22. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/loss_backend_tens.py +0 -0
  23. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/loss_backend_torch.py +0 -0
  24. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/scat2D.py +0 -0
  25. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/scat_cov2D.py +0 -0
  26. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/scat_cov_map.py +0 -0
  27. {foscat-3.0.19 → foscat-3.0.21}/src/foscat/scat_cov_map2D.py +0 -0
  28. {foscat-3.0.19 → foscat-3.0.21}/src/foscat.egg-info/dependency_links.txt +0 -0
  29. {foscat-3.0.19 → foscat-3.0.21}/src/foscat.egg-info/requires.txt +0 -0
  30. {foscat-3.0.19 → foscat-3.0.21}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: foscat
3
- Version: 3.0.19
3
+ Version: 3.0.21
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Home-page: https://github.com/jmdelouis/FOSCAT
6
6
  Author: Jean-Marc DELOUIS
@@ -9,12 +9,6 @@ Maintainer: Theo Foulquier
9
9
  Maintainer-email: theo.foulquier@ifremer.fr
10
10
  License: MIT
11
11
  Keywords: Scattering transform,Component separation,denoising
12
- Requires-Dist: imageio
13
- Requires-Dist: imagecodecs
14
- Requires-Dist: matplotlib
15
- Requires-Dist: numpy
16
- Requires-Dist: tensorflow
17
- Requires-Dist: healpy
18
12
 
19
13
  Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising.
20
14
  A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO.
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
3
3
 
4
4
  setup(
5
5
  name='foscat',
6
- version='3.0.19',
6
+ version='3.0.21',
7
7
  description='Generate synthetic Healpix or 2D data using Cross Scattering Transform' ,
8
8
  long_description='Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising. \n A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO. \n Complete doc can be found at https://foscat-documentation.readthedocs.io/en/latest/index.html. \n\n List of developers : J.-M. Delouis, T. Foulquier, L. Mousset, T. Odaka, F. Paul, E. Allys ' ,
9
9
  license='MIT',
@@ -0,0 +1,110 @@
1
+ import numpy as np
2
+ import pickle
3
+ import foscat.scat_cov as sc
4
+
5
+
6
+ class CNN:
7
+
8
+ def __init__(self,
9
+ scat_operator=None,
10
+ nparam=1,
11
+ nscale=1,
12
+ chanlist=[],
13
+ in_nside=1,
14
+ n_chan_in=1,
15
+ nbatch=1,
16
+ SEED=1234,
17
+ filename=None):
18
+
19
+ if filename is not None:
20
+ outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
21
+ self.scat_operator=sc.funct(KERNELSZ=outlist[3],all_type=outlist[7])
22
+ self.KERNELSZ= self.scat_operator.KERNELSZ
23
+ self.all_type= self.scat_operator.all_type
24
+ self.npar=outlist[2]
25
+ self.nscale=outlist[5]
26
+ self.chanlist=outlist[0]
27
+ self.in_nside=outlist[4]
28
+ self.nbatch=outlist[1]
29
+ self.n_chan_in=outlist[8]
30
+ self.x=self.scat_operator.backend.bk_cast(outlist[6])
31
+ self.out_nside=self.in_nside//(2**self.nscale)
32
+ else:
33
+ self.nscale=nscale
34
+ self.nbatch=nbatch
35
+ self.npar=nparam
36
+ self.n_chan_in=n_chan_in
37
+ self.scat_operator=scat_operator
38
+ if len(chanlist)!=nscale+1:
39
+ print('len of chanlist (here %d) should of nscale+1 (here %d)'%(len(chanlist),nscale+1))
40
+ return None
41
+
42
+ self.chanlist=chanlist
43
+ self.KERNELSZ= scat_operator.KERNELSZ
44
+ self.all_type= scat_operator.all_type
45
+ self.in_nside=in_nside
46
+ self.out_nside=self.in_nside//(2**self.nscale)
47
+
48
+ np.random.seed(SEED)
49
+ self.x=scat_operator.backend.bk_cast(np.random.randn(self.get_number_of_weights())/(self.KERNELSZ*self.KERNELSZ))
50
+
51
+
52
+
53
+ def save(self,filename):
54
+
55
+ outlist=[self.chanlist, \
56
+ self.nbatch, \
57
+ self.npar, \
58
+ self.KERNELSZ, \
59
+ self.in_nside, \
60
+ self.nscale, \
61
+ self.get_weights().numpy(), \
62
+ self.all_type, \
63
+ self.n_chan_in]
64
+
65
+ myout=open("%s.pkl"%(filename),"wb")
66
+ pickle.dump(outlist,myout)
67
+ myout.close()
68
+
69
+ def get_number_of_weights(self):
70
+ totnchan=0
71
+ for i in range(self.nscale):
72
+ totnchan=totnchan+self.chanlist[i]*self.chanlist[i+1]
73
+ return self.npar*12*self.out_nside**2*self.chanlist[self.nscale] \
74
+ +totnchan*self.KERNELSZ*self.KERNELSZ+self.KERNELSZ*self.KERNELSZ*self.n_chan_in*self.chanlist[0]
75
+
76
+ def set_weights(self,x):
77
+ self.x=x
78
+
79
+ def get_weights(self):
80
+ return self.x
81
+
82
+ def eval(self,im):
83
+
84
+ x=self.x
85
+ ww=self.scat_operator.backend.bk_reshape(x[0:self.KERNELSZ*self.KERNELSZ*self.n_chan_in*self.chanlist[0]],
86
+ [self.KERNELSZ*self.KERNELSZ,self.n_chan_in,self.chanlist[0]])
87
+ nn=self.KERNELSZ*self.KERNELSZ*self.n_chan_in*self.chanlist[0]
88
+
89
+ im=self.scat_operator.healpix_layer(im,ww)
90
+ im=self.scat_operator.backend.bk_relu(im)
91
+
92
+ for k in range(self.nscale):
93
+ print(im.shape)
94
+ ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]],
95
+ [self.KERNELSZ*self.KERNELSZ,self.chanlist[k],self.chanlist[k+1]])
96
+ nn=nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]
97
+ im=self.scat_operator.healpix_layer(im,ww)
98
+ im=self.scat_operator.backend.bk_relu(im)
99
+ im=self.scat_operator.ud_grade_2(im,axis=0)
100
+
101
+
102
+ ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.npar*12*self.out_nside**2*self.chanlist[self.nscale]], \
103
+ [12*self.out_nside**2*self.chanlist[self.nscale],self.npar])
104
+
105
+ im=self.scat_operator.backend.bk_matmul(self.scat_operator.backend.bk_reshape(im,[1,12*self.out_nside**2*self.chanlist[self.nscale]]),ww)
106
+ im=self.scat_operator.backend.bk_reshape(im,[self.npar])
107
+ im=self.scat_operator.backend.bk_relu(im)
108
+
109
+ return im
110
+
@@ -32,7 +32,7 @@ class FoCUS:
32
32
  mpi_size=1,
33
33
  mpi_rank=0):
34
34
 
35
- self.__version__ = '3.0.19'
35
+ self.__version__ = '3.0.21'
36
36
  # P00 coeff for normalization for scat_cov
37
37
  self.TMPFILE_VERSION=TMPFILE_VERSION
38
38
  self.P1_dic = None
@@ -160,8 +160,7 @@ class Synthesis:
160
160
  self.l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
161
161
 
162
162
  for istep in range(nstep):
163
-
164
-
163
+
165
164
  for k in range(self.number_of_loss):
166
165
  if self.loss_class[k].batch is None:
167
166
  l_batch=None
@@ -347,12 +347,10 @@ class foscat_backend:
347
347
  return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
348
348
 
349
349
  if self.BACKEND==self.TORCH:
350
- print(x.shape)
351
350
  tmp=self.backend.nn.functional.interpolate(x,
352
351
  size=shape,
353
352
  mode='bilinear',
354
353
  align_corners=False)
355
- print(tmp.shape)
356
354
  return self.bk_cast(tmp)
357
355
  if self.BACKEND==self.NUMPY:
358
356
  return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
@@ -9,6 +9,7 @@ import sys
9
9
  tf_defined = 'tensorflow' in sys.modules
10
10
 
11
11
  if tf_defined:
12
+ import tensorflow as tf
12
13
  tf_function = tf.function # Facultatif : si vous voulez utiliser TensorFlow dans ce script
13
14
  else:
14
15
  def tf_function(func):
@@ -64,44 +65,55 @@ class scat:
64
65
  self.j1 , \
65
66
  self.j2 ,backend=self.backend)
66
67
 
68
+
67
69
  def domult(self,x,y):
68
- if x.dtype==y.dtype:
70
+ try:
69
71
  return x*y
70
-
71
- if self.backend.bk_is_complex(x):
72
-
73
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
74
- else:
75
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
76
-
72
+ except:
73
+ if x.dtype==y.dtype:
74
+ return x*y
75
+ if self.backend.bk_is_complex(x):
76
+
77
+ return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
78
+ else:
79
+ return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
80
+
77
81
  def dodiv(self,x,y):
78
- if x.dtype==y.dtype:
82
+ try:
79
83
  return x/y
80
- if self.backend.bk_is_complex(x):
84
+ except:
85
+ if x.dtype==y.dtype:
86
+ return x/y
87
+ if self.backend.bk_is_complex(x):
81
88
 
82
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
83
- else:
84
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
89
+ return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
90
+ else:
91
+ return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
85
92
 
86
93
  def domin(self,x,y):
87
- if x.dtype==y.dtype:
94
+ try:
88
95
  return x-y
89
-
90
- if self.backend.bk_is_complex(x):
91
-
92
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
93
- else:
94
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
96
+ except:
97
+ if x.dtype==y.dtype:
98
+ return x-y
99
+
100
+ if self.backend.bk_is_complex(x):
101
+
102
+ return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
103
+ else:
104
+ return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
95
105
 
96
106
  def doadd(self,x,y):
97
- if x.dtype==y.dtype:
107
+ try:
98
108
  return x+y
99
-
100
- if self.backend.bk_is_complex(x):
101
-
102
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
103
- else:
104
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
109
+ except:
110
+ if x.dtype==y.dtype:
111
+ return x+y
112
+ if self.backend.bk_is_complex(x):
113
+
114
+ return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
115
+ else:
116
+ return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
105
117
 
106
118
  def relu(self):
107
119
 
@@ -337,7 +349,7 @@ class scat:
337
349
  if len(tmp.shape)==4:
338
350
  for k in range(tmp.shape[3]):
339
351
  for i1 in range(tmp.shape[0]):
340
- for i2 in range(tmp.shape[0]):
352
+ for i2 in range(tmp.shape[1]):
341
353
  if test is None:
342
354
  test=1
343
355
  plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
@@ -1036,7 +1048,7 @@ class funct(FOC.FoCUS):
1036
1048
  if mask is not None:
1037
1049
  if list(image1.shape)!=list(mask.shape)[1:]:
1038
1050
  print('The mask should have the same size than the input image to eval Scattering')
1039
- print(image1.shape,mask.shape)
1051
+ print('Image shape ',image1.shape,'Mask shape ',mask.shape)
1040
1052
  exit(0)
1041
1053
  if self.use_2D and len(image1.shape)<2:
1042
1054
  print('To work with 2D scattering transform, two dimension is needed, input map has only on dimension')
@@ -1117,7 +1129,6 @@ class funct(FOC.FoCUS):
1117
1129
  else:
1118
1130
  # if the kernel size is bigger than 3 increase the binning before smoothing
1119
1131
  if self.use_2D:
1120
- print(axis,image1.shape)
1121
1132
  l_image1=self.up_grade(l_image1,I1.shape[axis]*4,axis=axis,nouty=I1.shape[axis+1]*4)
1122
1133
  vmask=self.up_grade(vmask,I1.shape[axis]*4,axis=1,nouty=I1.shape[axis+1]*4)
1123
1134
  else:
@@ -8,6 +8,7 @@ import sys
8
8
  tf_defined = 'tensorflow' in sys.modules
9
9
 
10
10
  if tf_defined:
11
+ import tensorflow as tf
11
12
  tf_function = tf.function # Facultatif : si vous voulez utiliser TensorFlow dans ce script
12
13
  else:
13
14
  def tf_function(func):
@@ -28,7 +29,41 @@ class scat1D:
28
29
  self.j2=j2
29
30
  self.cross=cross
30
31
  self.backend=backend
32
+
33
+ # ---------------------------------------------−---------
34
+ def build_flat(self,table):
35
+ shape=table.shape
36
+ ndata=1
37
+ for k in range(1,len(table.shape)):
38
+ ndata=ndata*table.shape[k]
39
+ return self.backend.bk_reshape(table,[table.shape[0],ndata])
31
40
 
41
+ # ---------------------------------------------−---------
42
+ def flatten(self,S2L=False):
43
+ if not S2L:
44
+ if isinstance(self.P00,np.ndarray):
45
+ return np.concatenate([self.build_flat(lf.S0),
46
+ self.build_flat(lf.S1),
47
+ self.build_flat(lf.P00),
48
+ self.build_flat(lf.S2)],1)
49
+ else:
50
+ return self.backend.bk_concat([self.build_flat(self.S0),
51
+ self.build_flat(self.S1),
52
+ self.build_flat(self.P00),
53
+ self.build_flat(self.S2)],1)
54
+ else:
55
+ if isinstance(self.P00,np.ndarray):
56
+ return np.concatenate([self.build_flat(lf.S0),
57
+ self.build_flat(lf.S1),
58
+ self.build_flat(lf.P00),
59
+ self.build_flat(lf.S2),
60
+ self.build_flat(lf.S2L)],1)
61
+ else:
62
+ return self.backend.bk_concat([self.build_flat(self.S0),
63
+ self.build_flat(self.S1),
64
+ self.build_flat(self.P00),
65
+ self.build_flat(self.S2),
66
+ self.build_flat(self.S2L)],1)
32
67
  def get_j_idx(self):
33
68
  return self.j1,self.j2
34
69
 
@@ -62,6 +97,7 @@ class scat1D:
62
97
  def domult(self,x,y):
63
98
  if x.dtype==y.dtype:
64
99
  return x*y
100
+
65
101
  if self.backend.bk_is_complex(x):
66
102
 
67
103
  return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
@@ -80,6 +116,7 @@ class scat1D:
80
116
  def domin(self,x,y):
81
117
  if x.dtype==y.dtype:
82
118
  return x-y
119
+
83
120
  if self.backend.bk_is_complex(x):
84
121
 
85
122
  return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
@@ -89,6 +126,7 @@ class scat1D:
89
126
  def doadd(self,x,y):
90
127
  if x.dtype==y.dtype:
91
128
  return x+y
129
+
92
130
  if self.backend.bk_is_complex(x):
93
131
 
94
132
  return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
@@ -572,8 +610,6 @@ class scat1D:
572
610
  s2=self.S2.numpy()
573
611
  s2l=self.S2L.numpy()
574
612
 
575
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
576
-
577
613
  if isinstance(threshold,scat1D):
578
614
  if isinstance(threshold.S1,np.ndarray):
579
615
  s1th=threshold.S1
@@ -663,7 +699,6 @@ class scat1D:
663
699
  s2l[:,i0]=s2l[:,i1]
664
700
  else:
665
701
  idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
666
- print(i0,i2)
667
702
  if len(idx[0])>0:
668
703
  s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
669
704
  idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
@@ -681,7 +716,6 @@ class scat1D:
681
716
  p0[np.isnan(p0)]=0.0
682
717
  s2[np.isnan(s2)]=0.0
683
718
  s2l[np.isnan(s2l)]=0.0
684
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
685
719
 
686
720
  return scat1D(self.backend.constant(p0),self.S0,
687
721
  self.backend.constant(s1),
@@ -815,9 +849,14 @@ class funct(FOC.FoCUS):
815
849
  def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,Add_R45=False,axis=0):
816
850
  # Check input consistency
817
851
  if mask is not None:
818
- if list(image1.shape)!=list(mask.shape)[1:]:
819
- print('The mask should have the same size than the input timeline to eval Scattering')
820
- exit(0)
852
+ if len(image1.shape)==1:
853
+ if image1.shape[0]!=mask.shape[1]:
854
+ print('The mask should have the same size than the input timeline to eval Scattering')
855
+ exit(0)
856
+ else:
857
+ if image1.shape[1]!=mask.shape[1]:
858
+ print('The mask should have the same size than the input timeline to eval Scattering')
859
+ exit(0)
821
860
 
822
861
  ### AUTO OR CROSS
823
862
  cross = False
@@ -830,7 +869,7 @@ class funct(FOC.FoCUS):
830
869
  # determine jmax and nside corresponding to the input map
831
870
  im_shape = image1.shape
832
871
 
833
- nside=im_shape[axis]
872
+ nside=im_shape[len(image1.shape)-1]
834
873
  npix=nside
835
874
 
836
875
  jmax=int(np.log(nside)/np.log(2)) #-self.OSTEP
@@ -865,11 +904,19 @@ class funct(FOC.FoCUS):
865
904
  l_image1=I1
866
905
  if cross:
867
906
  l_image2=I2
868
-
869
- s0 = self.backend.bk_reduce_sum(l_image1*vmask,axis=axis+1)+s0_off
907
+ if len(image1.shape)==1:
908
+ s0 = self.backend.bk_reduce_sum(l_image1*vmask,axis=axis+1)
909
+ if cross and Auto==False:
910
+ s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(l_image2*vmask,axis=axis)])
911
+ else:
912
+ lmask=self.backend.bk_expand_dims(vmask,0)
913
+ lim=self.backend.bk_expand_dims(l_image1,1)
914
+ s0 = self.backend.bk_reduce_sum(lim*lmask,axis=axis+2)
915
+ if cross and Auto==False:
916
+ lim=self.backend.bk_expand_dims(l_image2,1)
917
+ s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(lim*lmask,axis=axis+2)])
918
+
870
919
 
871
- if cross and Auto==False:
872
- s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(l_image2*vmask,axis=axis)+s0_off])
873
920
 
874
921
  s1=None
875
922
  s2=None
@@ -9,6 +9,7 @@ import sys
9
9
  tf_defined = 'tensorflow' in sys.modules
10
10
 
11
11
  if tf_defined:
12
+ import tensorflow as tf
12
13
  tf_function = tf.function # Facultatif : si vous voulez utiliser TensorFlow dans ce script
13
14
  else:
14
15
  def tf_function(func):
@@ -576,41 +577,53 @@ class scat_cov:
576
577
  s1=s1, c10=c10,backend=self.backend)
577
578
 
578
579
  def domult(self,x,y):
579
- if x.dtype==y.dtype:
580
+ try:
580
581
  return x*y
581
- if self.backend.bk_is_complex(x):
582
-
583
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
584
- else:
585
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
582
+ except:
583
+ if x.dtype==y.dtype:
584
+ return x*y
585
+ if self.backend.bk_is_complex(x):
586
+
587
+ return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
588
+ else:
589
+ return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
586
590
 
587
591
  def dodiv(self,x,y):
588
- if x.dtype==y.dtype:
592
+ try:
589
593
  return x/y
590
- if self.backend.bk_is_complex(x):
594
+ except:
595
+ if x.dtype==y.dtype:
596
+ return x/y
597
+ if self.backend.bk_is_complex(x):
591
598
 
592
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
593
- else:
594
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
599
+ return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
600
+ else:
601
+ return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
595
602
 
596
603
  def domin(self,x,y):
597
- if x.dtype==y.dtype:
604
+ try:
598
605
  return x-y
599
-
600
- if self.backend.bk_is_complex(x):
601
-
602
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
603
- else:
604
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
606
+ except:
607
+ if x.dtype==y.dtype:
608
+ return x-y
609
+
610
+ if self.backend.bk_is_complex(x):
611
+
612
+ return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
613
+ else:
614
+ return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
605
615
 
606
616
  def doadd(self,x,y):
607
- if x.dtype==y.dtype:
617
+ try:
608
618
  return x+y
609
- if self.backend.bk_is_complex(x):
610
-
611
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
612
- else:
613
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
619
+ except:
620
+ if x.dtype==y.dtype:
621
+ return x+y
622
+ if self.backend.bk_is_complex(x):
623
+
624
+ return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
625
+ else:
626
+ return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
614
627
 
615
628
 
616
629
  def __mul__(self, other):
@@ -735,7 +748,7 @@ class scat_cov:
735
748
  return scat_cov(self.S0,self.backend.constant(p0),self.backend.constant(c01),
736
749
  self.backend.constant(c11),s1=s1,c10=c10,backend=self.backend)
737
750
 
738
- def plot(self, name=None, hold=True, color='blue', lw=1, legend=True):
751
+ def plot(self, name=None, hold=True, color='blue', lw=1, legend=True,norm=False):
739
752
 
740
753
  import matplotlib.pyplot as plt
741
754
 
@@ -754,7 +767,7 @@ class scat_cov:
754
767
  if len(tmp.shape)>3:
755
768
  for k in range(tmp.shape[3]):
756
769
  for i1 in range(tmp.shape[0]):
757
- for i2 in range(tmp.shape[0]):
770
+ for i2 in range(tmp.shape[1]):
758
771
  if test is None:
759
772
  test=1
760
773
  plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $S_1$' % (name), lw=lw)
@@ -762,7 +775,7 @@ class scat_cov:
762
775
  plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
763
776
  else:
764
777
  for i1 in range(tmp.shape[0]):
765
- for i2 in range(tmp.shape[0]):
778
+ for i2 in range(tmp.shape[1]):
766
779
  if test is None:
767
780
  test=1
768
781
  plt.plot(tmp[i1,i2,:],color=color, label=r'%s $S_1$' % (name), lw=lw)
@@ -776,10 +789,11 @@ class scat_cov:
776
789
  test=None
777
790
  plt.subplot(2, 2, 2)
778
791
  tmp=abs(self.get_np(self.P00))
792
+ ntmp=np.sqrt(tmp)
779
793
  if len(tmp.shape)>3:
780
794
  for k in range(tmp.shape[3]):
781
795
  for i1 in range(tmp.shape[0]):
782
- for i2 in range(tmp.shape[0]):
796
+ for i2 in range(tmp.shape[1]):
783
797
  if test is None:
784
798
  test=1
785
799
  plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
@@ -787,7 +801,7 @@ class scat_cov:
787
801
  plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
788
802
  else:
789
803
  for i1 in range(tmp.shape[0]):
790
- for i2 in range(tmp.shape[0]):
804
+ for i2 in range(tmp.shape[1]):
791
805
  if test is None:
792
806
  test=1
793
807
  plt.plot(tmp[i1,i2,:],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
@@ -819,15 +833,18 @@ class scat_cov:
819
833
  for i2 in range(j1.max()+1):
820
834
  for i3 in range(tmp.shape[3]):
821
835
  for i4 in range(tmp.shape[4]):
836
+ dtmp=tmp[i0,i1,j1==i2,i3,i4]
837
+ if norm:
838
+ dtmp=dtmp/ntmp[i0,i1,i2,i3]
822
839
  if j2[j1==i2].shape[0]==1:
823
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4],'.', \
840
+ ax1.plot(j2[j1==i2]+n,dtmp,'.', \
824
841
  color=color, lw=lw)
825
842
  else:
826
843
  if legend and test is None:
827
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
844
+ ax1.plot(j2[j1==i2]+n,dtmp, \
828
845
  color=color, label=lname, lw=lw)
829
846
  test=1
830
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
847
+ ax1.plot(j2[j1==i2]+n,dtmp, \
831
848
  color=color, lw=lw)
832
849
  tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
833
850
  tabx=tabx+[k+n for k in j2[j1==i2]]
@@ -840,15 +857,18 @@ class scat_cov:
840
857
  for i1 in range(tmp.shape[1]):
841
858
  for i2 in range(j1.max()+1):
842
859
  for i3 in range(tmp.shape[3]):
860
+ dtmp=tmp[i0,i1,j1==i2,i3]
861
+ if norm:
862
+ dtmp=dtmp/ntmp[i0,i1,i2]
843
863
  if j2[j1==i2].shape[0]==1:
844
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3],'.', \
864
+ ax1.plot(j2[j1==i2]+n,dtmp,'.', \
845
865
  color=color, lw=lw)
846
866
  else:
847
867
  if legend and test is None:
848
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3], \
868
+ ax1.plot(j2[j1==i2]+n,dtmp, \
849
869
  color=color, label=lname, lw=lw)
850
870
  test=1
851
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3], \
871
+ ax1.plot(j2[j1==i2]+n,dtmp, \
852
872
  color=color, lw=lw)
853
873
  tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
854
874
  tabx=tabx+[k+n for k in j2[j1==i2]]
@@ -894,6 +914,7 @@ class scat_cov:
894
914
  tabnx=[]
895
915
  tab2x=[]
896
916
  tab2nx=[]
917
+ ntmp=ntmp*ntmp
897
918
  if len(tmp.shape)>4:
898
919
  for i0 in range(tmp.shape[0]):
899
920
  for i1 in range(tmp.shape[1]):
@@ -904,15 +925,18 @@ class scat_cov:
904
925
  for i3 in range(tmp.shape[3]):
905
926
  for i4 in range(tmp.shape[4]):
906
927
  for i5 in range(tmp.shape[5]):
928
+ dtmp=tmp[i0,i1,idx,i3,i4,i5]
929
+ if norm:
930
+ dtmp=dtmp/ntmp[i0,i1,i2,i3]
907
931
  if len(idx)==1:
908
- ax1.plot(np.arange(len(idx))+n,tmp[i0,i1,idx,i3,i4,i5],'.', \
932
+ ax1.plot(np.arange(len(idx))+n,dtmp,'.', \
909
933
  color=color, lw=lw)
910
934
  else:
911
935
  if legend and test is None:
912
- ax1.plot(np.arange(len(idx))+n,tmp[i0,i1,idx,i3,i4,i5], \
936
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
913
937
  color=color, label=lname, lw=lw)
914
938
  test=1
915
- ax1.plot(np.arange(len(idx))+n,tmp[i0,i1,idx,i3,i4,i5], \
939
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
916
940
  color=color, lw=lw)
917
941
  tabnx=tabnx+[r'%d,%d'%(j2[k],j3[k]) for k in idx]
918
942
  tabx=tabx+[k+n for k in range(len(idx))]
@@ -928,15 +952,18 @@ class scat_cov:
928
952
  for i2b in range(j2[j1==i2].max()+1):
929
953
  idx=np.where((j1==i2)*(j2==i2b))[0]
930
954
  for i3 in range(tmp.shape[3]):
955
+ dtmp=tmp[i0,i1,idx,i3]
956
+ if norm:
957
+ dtmp=dtmp/ntmp[i0,i1,i2]
931
958
  if len(idx)==1:
932
- ax1.plot(np.arange(len(idx))+n,tmp[i0,i1,idx,i3],'.', \
959
+ ax1.plot(np.arange(len(idx))+n,dtmp,'.', \
933
960
  color=color, lw=lw)
934
961
  else:
935
962
  if legend and test is None:
936
- ax1.plot(np.arange(len(idx))+n,tmp[i0,i1,idx,i3], \
963
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
937
964
  color=color, label=lname, lw=lw)
938
965
  test=1
939
- ax1.plot(np.arange(len(idx))+n,tmp[i0,i1,idx,i3], \
966
+ ax1.plot(np.arange(len(idx))+n,dtmp, \
940
967
  color=color, lw=lw)
941
968
  tabnx=tabnx+[r'%d,%d'%(j2[k],j3[k]) for k in idx]
942
969
  tabx=tabx+[k+n for k in range(len(idx))]
@@ -10,6 +10,7 @@ import sys
10
10
  tf_defined = 'tensorflow' in sys.modules
11
11
 
12
12
  if tf_defined:
13
+ import tensorflow as tf
13
14
  tf_function = tf.function # Facultatif : si vous voulez utiliser TensorFlow dans ce script
14
15
  else:
15
16
  def tf_function(func):
@@ -374,42 +375,54 @@ class scat_cov1D:
374
375
  (self.C01 - other),
375
376
  c11,
376
377
  s1=s1, c10=c10,backend=self.backend)
377
-
378
- def domult(self,x,y):
379
- if x.dtype==y.dtype:
378
+ def domult(self,x,y):
379
+ try:
380
380
  return x*y
381
-
382
- if self.backend.bk_is_complex(x):
383
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
384
- else:
385
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
386
-
381
+ except:
382
+ if x.dtype==y.dtype:
383
+ return x*y
384
+ if self.backend.bk_is_complex(x):
385
+
386
+ return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
387
+ else:
388
+ return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
389
+
387
390
  def dodiv(self,x,y):
388
- if x.dtype==y.dtype:
391
+ try:
389
392
  return x/y
390
-
391
- if self.backend.bk_is_complex(x):
392
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
393
- else:
394
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
393
+ except:
394
+ if x.dtype==y.dtype:
395
+ return x/y
396
+ if self.backend.bk_is_complex(x):
397
+
398
+ return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
399
+ else:
400
+ return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
395
401
 
396
402
  def domin(self,x,y):
397
- if x.dtype==y.dtype:
403
+ try:
398
404
  return x-y
399
-
400
- if self.backend.bk_is_complex(x):
401
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
402
- else:
403
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
405
+ except:
406
+ if x.dtype==y.dtype:
407
+ return x-y
408
+
409
+ if self.backend.bk_is_complex(x):
410
+
411
+ return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
412
+ else:
413
+ return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
404
414
 
405
415
  def doadd(self,x,y):
406
- if x.dtype==y.dtype:
416
+ try:
407
417
  return x+y
408
-
409
- if self.backend.bk_is_complex(x):
410
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
411
- else:
412
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
418
+ except:
419
+ if x.dtype==y.dtype:
420
+ return x+y
421
+ if self.backend.bk_is_complex(x):
422
+
423
+ return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
424
+ else:
425
+ return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
413
426
 
414
427
 
415
428
  def __mul__(self, other):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: foscat
3
- Version: 3.0.19
3
+ Version: 3.0.21
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Home-page: https://github.com/jmdelouis/FOSCAT
6
6
  Author: Jean-Marc DELOUIS
@@ -9,12 +9,6 @@ Maintainer: Theo Foulquier
9
9
  Maintainer-email: theo.foulquier@ifremer.fr
10
10
  License: MIT
11
11
  Keywords: Scattering transform,Component separation,denoising
12
- Requires-Dist: imageio
13
- Requires-Dist: imagecodecs
14
- Requires-Dist: matplotlib
15
- Requires-Dist: numpy
16
- Requires-Dist: tensorflow
17
- Requires-Dist: healpy
18
12
 
19
13
  Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising.
20
14
  A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO.
@@ -1,6 +1,7 @@
1
1
  README.md
2
2
  setup.cfg
3
3
  setup.py
4
+ src/foscat/CNN.py
4
5
  src/foscat/CircSpline.py
5
6
  src/foscat/FoCUS.py
6
7
  src/foscat/GCNN.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes