foscat 3.0.20__py3-none-any.whl → 3.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/CNN.py ADDED
@@ -0,0 +1,110 @@
1
+ import numpy as np
2
+ import pickle
3
+ import foscat.scat_cov as sc
4
+
5
+
6
class CNN:
    """Small HEALPix CNN built on top of a foscat scattering operator.

    Layout of the computation in :meth:`eval`:

    1. one ``healpix_layer`` + ReLU mapping ``n_chan_in`` -> ``chanlist[0]``
       channels;
    2. ``nscale`` blocks of ``healpix_layer`` + ReLU + ``ud_grade_2``, block
       ``k`` mapping ``chanlist[k]`` -> ``chanlist[k+1]`` channels;
    3. a dense layer mapping the flattened
       ``12*out_nside**2*chanlist[nscale]`` features to ``npar`` outputs,
       followed by a final ReLU.

    All trainable weights live in one flat vector ``self.x``, consumed in the
    order above (see :meth:`get_number_of_weights`).
    """

    def __init__(self,
                 scat_operator=None,
                 nparam=1,
                 nscale=1,
                 chanlist=None,
                 in_nside=1,
                 n_chan_in=1,
                 nbatch=1,
                 SEED=1234,
                 filename=None):
        """Build a new network, or reload one written by :meth:`save`.

        Parameters
        ----------
        scat_operator : foscat operator
            Must provide ``KERNELSZ``, ``all_type``, ``backend``,
            ``healpix_layer`` and ``ud_grade_2``. Required unless
            ``filename`` is given.
        nparam : int
            Number of outputs of the final dense layer.
        nscale : int
            Number of convolution + ``ud_grade_2`` blocks.
        chanlist : list of int
            ``nscale + 1`` channel counts.
        in_nside : int
            nside of the input map; ``out_nside = in_nside // 2**nscale``.
        n_chan_in : int
            Number of channels of the input map.
        nbatch : int
            Stored and saved with the model; not used by :meth:`eval`.
        SEED : int
            Seed of the random initial weights.
        filename : str, optional
            If given, load ``"<filename>.pkl"`` and ignore the other arguments.

        Raises
        ------
        ValueError
            If ``chanlist`` does not have ``nscale + 1`` entries, or if no
            ``scat_operator`` is given when building a new network.
        """
        if filename is not None:
            self._load(filename)
            return

        if chanlist is None:
            chanlist = []
        if len(chanlist) != nscale + 1:
            # Previously this only printed and returned, leaving a
            # half-initialised object that failed later with AttributeError.
            raise ValueError('len of chanlist (here %d) should of nscale+1 (here %d)'
                             % (len(chanlist), nscale + 1))
        if scat_operator is None:
            raise ValueError('scat_operator is required when filename is None')

        self.scat_operator = scat_operator
        self.nscale = nscale
        self.nbatch = nbatch
        self.npar = nparam
        self.n_chan_in = n_chan_in
        # Copy so later mutation of the caller's list cannot change the model.
        self.chanlist = list(chanlist)
        self.KERNELSZ = scat_operator.KERNELSZ
        self.all_type = scat_operator.all_type
        self.in_nside = in_nside
        self.out_nside = self.in_nside // (2 ** self.nscale)

        # A local RandomState gives exactly the draws of
        # np.random.seed(SEED); np.random.randn(...) without resetting the
        # global NumPy RNG as a side effect.
        rng = np.random.RandomState(SEED)
        ksz2 = self.KERNELSZ * self.KERNELSZ
        self.x = scat_operator.backend.bk_cast(
            rng.randn(self.get_number_of_weights()) / ksz2)

    def _load(self, filename):
        """Restore configuration and weights from ``"<filename>.pkl"`` (see :meth:`save`)."""
        # NOTE: pickle can execute arbitrary code while loading; only load
        # model files from trusted sources.
        with open("%s.pkl" % (filename), "rb") as fin:
            outlist = pickle.load(fin)
        self.scat_operator = sc.funct(KERNELSZ=outlist[3], all_type=outlist[7])
        self.KERNELSZ = self.scat_operator.KERNELSZ
        self.all_type = self.scat_operator.all_type
        self.npar = outlist[2]
        self.nscale = outlist[5]
        self.chanlist = outlist[0]
        self.in_nside = outlist[4]
        self.nbatch = outlist[1]
        self.n_chan_in = outlist[8]
        self.x = self.scat_operator.backend.bk_cast(outlist[6])
        self.out_nside = self.in_nside // (2 ** self.nscale)

    def save(self, filename):
        """Write configuration and weights to ``"<filename>.pkl"``.

        The pickled list is indexed as: 0 chanlist, 1 nbatch, 2 npar,
        3 KERNELSZ, 4 in_nside, 5 nscale, 6 weights, 7 all_type, 8 n_chan_in.
        This is the layout read back by ``CNN(filename=...)``.
        """
        weights = self.get_weights()
        # Backend tensors expose .numpy(); a plain NumPy array is stored as is.
        if hasattr(weights, 'numpy'):
            weights = weights.numpy()
        outlist = [self.chanlist,
                   self.nbatch,
                   self.npar,
                   self.KERNELSZ,
                   self.in_nside,
                   self.nscale,
                   weights,
                   self.all_type,
                   self.n_chan_in]
        with open("%s.pkl" % (filename), "wb") as fout:
            pickle.dump(outlist, fout)

    def get_number_of_weights(self):
        """Return the length of the flat weight vector consumed by :meth:`eval`."""
        ksz2 = self.KERNELSZ * self.KERNELSZ
        n_input = ksz2 * self.n_chan_in * self.chanlist[0]
        n_conv = ksz2 * sum(self.chanlist[k] * self.chanlist[k + 1]
                            for k in range(self.nscale))
        n_dense = self.npar * 12 * self.out_nside ** 2 * self.chanlist[self.nscale]
        return n_dense + n_conv + n_input

    def set_weights(self, x):
        """Replace the flat weight vector."""
        self.x = x

    def get_weights(self):
        """Return the flat weight vector."""
        return self.x

    def eval(self, im):
        """Run the network on ``im`` and return a vector of ``npar`` outputs.

        ``im`` is passed straight to ``scat_operator.healpix_layer``. Its
        layout must match what that operator expects (presumably pixels along
        axis 0, given ``ud_grade_2(..., axis=0)``; TODO confirm).
        """
        bk = self.scat_operator.backend
        x = self.x
        ksz2 = self.KERNELSZ * self.KERNELSZ

        # Input layer: n_chan_in -> chanlist[0] channels.
        nn = ksz2 * self.n_chan_in * self.chanlist[0]
        ww = bk.bk_reshape(x[0:nn], [ksz2, self.n_chan_in, self.chanlist[0]])
        im = self.scat_operator.healpix_layer(im, ww)
        im = bk.bk_relu(im)

        # nscale blocks: chanlist[k] -> chanlist[k+1] channels, then ud_grade_2.
        for k in range(self.nscale):
            nw = ksz2 * self.chanlist[k] * self.chanlist[k + 1]
            ww = bk.bk_reshape(x[nn:nn + nw],
                               [ksz2, self.chanlist[k], self.chanlist[k + 1]])
            nn += nw
            im = self.scat_operator.healpix_layer(im, ww)
            im = bk.bk_relu(im)
            im = self.scat_operator.ud_grade_2(im, axis=0)

        # Dense layer on the flattened features, then final ReLU.
        nfeat = 12 * self.out_nside ** 2 * self.chanlist[self.nscale]
        ww = bk.bk_reshape(x[nn:nn + self.npar * nfeat], [nfeat, self.npar])
        im = bk.bk_matmul(bk.bk_reshape(im, [1, nfeat]), ww)
        im = bk.bk_reshape(im, [self.npar])
        return bk.bk_relu(im)
foscat/FoCUS.py CHANGED
@@ -32,7 +32,7 @@ class FoCUS:
32
32
  mpi_size=1,
33
33
  mpi_rank=0):
34
34
 
35
- self.__version__ = '3.0.20'
35
+ self.__version__ = '3.0.21'
36
36
  # P00 coeff for normalization for scat_cov
37
37
  self.TMPFILE_VERSION=TMPFILE_VERSION
38
38
  self.P1_dic = None
foscat/backend.py CHANGED
@@ -347,12 +347,10 @@ class foscat_backend:
347
347
  return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
348
348
 
349
349
  if self.BACKEND==self.TORCH:
350
- print(x.shape)
351
350
  tmp=self.backend.nn.functional.interpolate(x,
352
351
  size=shape,
353
352
  mode='bilinear',
354
353
  align_corners=False)
355
- print(tmp.shape)
356
354
  return self.bk_cast(tmp)
357
355
  if self.BACKEND==self.NUMPY:
358
356
  return self.bk_cast(self.backend.image.resize(x,shape, method='bilinear'))
foscat/scat.py CHANGED
@@ -1048,7 +1048,7 @@ class funct(FOC.FoCUS):
1048
1048
  if mask is not None:
1049
1049
  if list(image1.shape)!=list(mask.shape)[1:]:
1050
1050
  print('The mask should have the same size than the input image to eval Scattering')
1051
- print(image1.shape,mask.shape)
1051
+ print('Image shape ',image1.shape,'Mask shape ',mask.shape)
1052
1052
  exit(0)
1053
1053
  if self.use_2D and len(image1.shape)<2:
1054
1054
  print('To work with 2D scattering transform, two dimension is needed, input map has only on dimension')
@@ -1129,7 +1129,6 @@ class funct(FOC.FoCUS):
1129
1129
  else:
1130
1130
  # if the kernel size is bigger than 3 increase the binning before smoothing
1131
1131
  if self.use_2D:
1132
- print(axis,image1.shape)
1133
1132
  l_image1=self.up_grade(l_image1,I1.shape[axis]*4,axis=axis,nouty=I1.shape[axis+1]*4)
1134
1133
  vmask=self.up_grade(vmask,I1.shape[axis]*4,axis=1,nouty=I1.shape[axis+1]*4)
1135
1134
  else:
foscat/scat1D.py CHANGED
@@ -29,7 +29,41 @@ class scat1D:
29
29
  self.j2=j2
30
30
  self.cross=cross
31
31
  self.backend=backend
32
+
33
# --------------------------------------------------------
def build_flat(self, table):
    """Reshape ``table`` to 2-D, keeping axis 0 and merging all later axes.

    The reshape goes through ``self.backend.bk_reshape`` with target shape
    ``[table.shape[0], product of table.shape[1:]]``. A 1-D table becomes a
    single-column table.
    """
    dims = table.shape
    nflat = 1
    for size in dims[1:]:
        nflat *= size
    return self.backend.bk_reshape(table, [dims[0], nflat])
32
40
 
41
# --------------------------------------------------------
def flatten(self, S2L=False):
    """Concatenate the flattened S0, S1, P00, S2 (and optionally S2L) tables.

    Each table is flattened with ``self.build_flat``, which keeps axis 0 and
    merges the others. The results are concatenated along axis 1.

    Parameters
    ----------
    S2L : bool
        If True, ``self.S2L`` is appended after ``self.S2``.

    Returns
    -------
    ``np.concatenate(..., 1)`` when ``self.P00`` is a ``np.ndarray``,
    otherwise ``self.backend.bk_concat(..., 1)``.
    """
    tables = [self.S0, self.S1, self.P00, self.S2]
    if S2L:
        tables.append(self.S2L)
    flat = [self.build_flat(t) for t in tables]
    if isinstance(self.P00, np.ndarray):
        return np.concatenate(flat, 1)
    return self.backend.bk_concat(flat, 1)
33
67
  def get_j_idx(self):
34
68
  return self.j1,self.j2
35
69
 
@@ -576,8 +610,6 @@ class scat1D:
576
610
  s2=self.S2.numpy()
577
611
  s2l=self.S2L.numpy()
578
612
 
579
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
580
-
581
613
  if isinstance(threshold,scat1D):
582
614
  if isinstance(threshold.S1,np.ndarray):
583
615
  s1th=threshold.S1
@@ -667,7 +699,6 @@ class scat1D:
667
699
  s2l[:,i0]=s2l[:,i1]
668
700
  else:
669
701
  idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
670
- print(i0,i2)
671
702
  if len(idx[0])>0:
672
703
  s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
673
704
  idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
@@ -685,7 +716,6 @@ class scat1D:
685
716
  p0[np.isnan(p0)]=0.0
686
717
  s2[np.isnan(s2)]=0.0
687
718
  s2l[np.isnan(s2l)]=0.0
688
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
689
719
 
690
720
  return scat1D(self.backend.constant(p0),self.S0,
691
721
  self.backend.constant(s1),
@@ -819,9 +849,14 @@ class funct(FOC.FoCUS):
819
849
  def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,Add_R45=False,axis=0):
820
850
  # Check input consistency
821
851
  if mask is not None:
822
- if list(image1.shape)!=list(mask.shape)[1:]:
823
- print('The mask should have the same size than the input timeline to eval Scattering')
824
- exit(0)
852
+ if len(image1.shape)==1:
853
+ if image1.shape[0]!=mask.shape[1]:
854
+ print('The mask should have the same size than the input timeline to eval Scattering')
855
+ exit(0)
856
+ else:
857
+ if image1.shape[1]!=mask.shape[1]:
858
+ print('The mask should have the same size than the input timeline to eval Scattering')
859
+ exit(0)
825
860
 
826
861
  ### AUTO OR CROSS
827
862
  cross = False
@@ -834,7 +869,7 @@ class funct(FOC.FoCUS):
834
869
  # determine jmax and nside corresponding to the input map
835
870
  im_shape = image1.shape
836
871
 
837
- nside=im_shape[axis]
872
+ nside=im_shape[len(image1.shape)-1]
838
873
  npix=nside
839
874
 
840
875
  jmax=int(np.log(nside)/np.log(2)) #-self.OSTEP
@@ -869,11 +904,19 @@ class funct(FOC.FoCUS):
869
904
  l_image1=I1
870
905
  if cross:
871
906
  l_image2=I2
872
-
873
- s0 = self.backend.bk_reduce_sum(l_image1*vmask,axis=axis+1)+s0_off
907
+ if len(image1.shape)==1:
908
+ s0 = self.backend.bk_reduce_sum(l_image1*vmask,axis=axis+1)
909
+ if cross and Auto==False:
910
+ s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(l_image2*vmask,axis=axis)])
911
+ else:
912
+ lmask=self.backend.bk_expand_dims(vmask,0)
913
+ lim=self.backend.bk_expand_dims(l_image1,1)
914
+ s0 = self.backend.bk_reduce_sum(lim*lmask,axis=axis+2)
915
+ if cross and Auto==False:
916
+ lim=self.backend.bk_expand_dims(l_image2,1)
917
+ s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(lim*lmask,axis=axis+2)])
918
+
874
919
 
875
- if cross and Auto==False:
876
- s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(l_image2*vmask,axis=axis)+s0_off])
877
920
 
878
921
  s1=None
879
922
  s2=None
foscat/scat_cov.py CHANGED
@@ -914,6 +914,7 @@ class scat_cov:
914
914
  tabnx=[]
915
915
  tab2x=[]
916
916
  tab2nx=[]
917
+ ntmp=ntmp*ntmp
917
918
  if len(tmp.shape)>4:
918
919
  for i0 in range(tmp.shape[0]):
919
920
  for i1 in range(tmp.shape[1]):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: foscat
3
- Version: 3.0.20
3
+ Version: 3.0.21
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Home-page: https://github.com/jmdelouis/FOSCAT
6
6
  Author: Jean-Marc DELOUIS
@@ -1,24 +1,25 @@
1
+ foscat/CNN.py,sha256=Evv7SkCtbzDKGlC0Gz9xFBtb-xLi3IBCYi2NsXLA9DY,4259
1
2
  foscat/CircSpline.py,sha256=610sgsWeZzRXYh7gYEqUmGQVrXoHSaFGKjH5mCdh4jU,1684
2
- foscat/FoCUS.py,sha256=dcWz86Xt_ZjiRYFZTuhK1gFbDns2hDeg9fJhG1tNsz8,67309
3
+ foscat/FoCUS.py,sha256=51ZRAhc6MsPatb3TG6lYUbvaN76r-DDJD1R2GRSqiuM,67309
3
4
  foscat/GCNN.py,sha256=TEW81DGRM4WL7RzH50VKQ-_oHbl5i3iQKuhdkkgKEO8,3831
4
5
  foscat/GetGPUinfo.py,sha256=6sJWKO_OeiA0SoGQQdCT_h3D8rZtrv_4hpBc8H3nZls,731
5
6
  foscat/Softmax.py,sha256=aCghstI2fchd8FHsBUcmPR4FGlCH9DglS7XMEWlKr8A,2709
6
7
  foscat/Spline1D.py,sha256=9oeM8SSHjpfUE5z72YxGt1RVt22vJYM1zhHbNBW8phw,1232
7
8
  foscat/Synthesis.py,sha256=oYtHFVTqalVzBQs5okJBnP4pzXFhBMds-pytdEm4Bqs,12611
8
9
  foscat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- foscat/backend.py,sha256=BJkx5CRxUHZr8WH4rvApuE8dIo1Y61fbtyfYhGByIIk,30402
10
+ foscat/backend.py,sha256=AZaMJSGoRGUSnjLKN8FF1tS16DmQZ70BjDl8zCDU9yg,30346
10
11
  foscat/backend_tens.py,sha256=zEFZ71j0nMNP9_91tz21ZVBTayr75l-sfONOLkJ8DyI,1432
11
12
  foscat/loss_backend_tens.py,sha256=WbGC4vy1pBg_bxUXnlCRiXX9WszN6MaUWUc_lUvZNvQ,1667
12
13
  foscat/loss_backend_torch.py,sha256=Fj_W3VwGgeD79eQ4jOxOmhZ548UKDRUb3JjUo2-gSWM,1755
13
- foscat/scat.py,sha256=A7xk0o3-wHh-IuXeKSkapvo2J9Hs5yMQZVuGx1FE7A8,60077
14
- foscat/scat1D.py,sha256=mSM_xDoQNoGYMV6JDmmfIX8n-Ulm6Ru8HtWqP8XTKqw,43914
14
+ foscat/scat.py,sha256=Ht_xyo7XKJJrUIbQIeucjhIrJo4RGrE63EyhTH8IYig,60061
15
+ foscat/scat1D.py,sha256=7egOWL7GXcJEenl8r1DSdArpE1Yvywgo-vxHAQ1gMzY,46269
15
16
  foscat/scat2D.py,sha256=Xtisjc5KsbLQAlbn71P0Xg1UIu3r1gUKXoYG2vIMK1M,523
16
- foscat/scat_cov.py,sha256=36Exeac1pXlnVzOWPBX0bSMMEyoUlzwK7kShDPeaF_s,107984
17
+ foscat/scat_cov.py,sha256=k_fx8aajqaBCAPmaABM1h9dg96j4QckUFvYaGG2vufQ,108007
17
18
  foscat/scat_cov1D.py,sha256=inAy_TWtUwJr6l9hX3u7r2Jud7DGy_CkjCfcjaUIdJI,58885
18
19
  foscat/scat_cov2D.py,sha256=8_XvC-lOEVUWP9vT3Wx10G_ATeVeh0SdrSWuBV7Xf5k,536
19
20
  foscat/scat_cov_map.py,sha256=ocU2xd41GtJhiU9S3dEv38KfPCvz0tJKY2f8lPxpm5c,2729
20
21
  foscat/scat_cov_map2D.py,sha256=t4llIt7DVIyU1b_u-dJSX4lBr2FhDict8RnNnHpRvHM,2754
21
- foscat-3.0.20.dist-info/METADATA,sha256=8A09buj_rHwPPsvJndDfgU4ecsQFcXzTel2RIuRZx3g,1013
22
- foscat-3.0.20.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- foscat-3.0.20.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
24
- foscat-3.0.20.dist-info/RECORD,,
22
+ foscat-3.0.21.dist-info/METADATA,sha256=95XstfOVlwkomdGlh8NhbEoO0tFSbk55BX4tOyqwNXg,1013
23
+ foscat-3.0.21.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
+ foscat-3.0.21.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
25
+ foscat-3.0.21.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.43.0)
2
+ Generator: bdist_wheel (0.42.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5