foscat 3.1.6__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/CNN.py CHANGED
@@ -1,112 +1,151 @@
1
- import numpy as np
2
1
  import pickle
2
+
3
+ import numpy as np
4
+
3
5
  import foscat.scat_cov as sc
4
-
6
+
5
7
 
6
8
  class CNN:
7
-
8
- def __init__(self,
9
- scat_operator=None,
10
- nparam=1,
11
- nscale=1,
12
- chanlist=[],
13
- in_nside=1,
14
- n_chan_in=1,
15
- nbatch=1,
16
- SEED=1234,
17
- filename=None):
9
+
10
+ def __init__(
11
+ self,
12
+ scat_operator=None,
13
+ nparam=1,
14
+ nscale=1,
15
+ chanlist=[],
16
+ in_nside=1,
17
+ n_chan_in=1,
18
+ nbatch=1,
19
+ SEED=1234,
20
+ filename=None,
21
+ ):
18
22
 
19
23
  if filename is not None:
20
- outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
21
- self.scat_operator=sc.funct(KERNELSZ=outlist[3],all_type=outlist[7])
22
- self.KERNELSZ= self.scat_operator.KERNELSZ
23
- self.all_type= self.scat_operator.all_type
24
- self.npar=outlist[2]
25
- self.nscale=outlist[5]
26
- self.chanlist=outlist[0]
27
- self.in_nside=outlist[4]
28
- self.nbatch=outlist[1]
29
- self.n_chan_in=outlist[8]
30
- self.x=self.scat_operator.backend.bk_cast(outlist[6])
31
- self.out_nside=self.in_nside//(2**self.nscale)
24
+ outlist = pickle.load(open("%s.pkl" % (filename), "rb"))
25
+ self.scat_operator = sc.funct(KERNELSZ=outlist[3], all_type=outlist[7])
26
+ self.KERNELSZ = self.scat_operator.KERNELSZ
27
+ self.all_type = self.scat_operator.all_type
28
+ self.npar = outlist[2]
29
+ self.nscale = outlist[5]
30
+ self.chanlist = outlist[0]
31
+ self.in_nside = outlist[4]
32
+ self.nbatch = outlist[1]
33
+ self.n_chan_in = outlist[8]
34
+ self.x = self.scat_operator.backend.bk_cast(outlist[6])
35
+ self.out_nside = self.in_nside // (2**self.nscale)
32
36
  else:
33
- self.nscale=nscale
34
- self.nbatch=nbatch
35
- self.npar=nparam
36
- self.n_chan_in=n_chan_in
37
- self.scat_operator=scat_operator
38
- if len(chanlist)!=nscale+1:
39
- print('len of chanlist (here %d) should of nscale+1 (here %d)'%(len(chanlist),nscale+1))
37
+ self.nscale = nscale
38
+ self.nbatch = nbatch
39
+ self.npar = nparam
40
+ self.n_chan_in = n_chan_in
41
+ self.scat_operator = scat_operator
42
+ if len(chanlist) != nscale + 1:
43
+ print(
44
+ "len of chanlist (here %d) should of nscale+1 (here %d)"
45
+ % (len(chanlist), nscale + 1)
46
+ )
40
47
  return None
41
-
42
- self.chanlist=chanlist
43
- self.KERNELSZ= scat_operator.KERNELSZ
44
- self.all_type= scat_operator.all_type
45
- self.in_nside=in_nside
46
- self.out_nside=self.in_nside//(2**self.nscale)
48
+
49
+ self.chanlist = chanlist
50
+ self.KERNELSZ = scat_operator.KERNELSZ
51
+ self.all_type = scat_operator.all_type
52
+ self.in_nside = in_nside
53
+ self.out_nside = self.in_nside // (2**self.nscale)
47
54
 
48
55
  np.random.seed(SEED)
49
- self.x=scat_operator.backend.bk_cast(np.random.randn(self.get_number_of_weights())/(self.KERNELSZ*self.KERNELSZ))
50
-
51
-
52
-
53
- def save(self,filename):
54
-
55
- outlist=[self.chanlist, \
56
- self.nbatch, \
57
- self.npar, \
58
- self.KERNELSZ, \
59
- self.in_nside, \
60
- self.nscale, \
61
- self.get_weights().numpy(), \
62
- self.all_type, \
63
- self.n_chan_in]
64
-
65
- myout=open("%s.pkl"%(filename),"wb")
66
- pickle.dump(outlist,myout)
56
+ self.x = scat_operator.backend.bk_cast(
57
+ np.random.randn(self.get_number_of_weights())
58
+ / (self.KERNELSZ * self.KERNELSZ)
59
+ )
60
+
61
+ def save(self, filename):
62
+
63
+ outlist = [
64
+ self.chanlist,
65
+ self.nbatch,
66
+ self.npar,
67
+ self.KERNELSZ,
68
+ self.in_nside,
69
+ self.nscale,
70
+ self.get_weights().numpy(),
71
+ self.all_type,
72
+ self.n_chan_in,
73
+ ]
74
+
75
+ myout = open("%s.pkl" % (filename), "wb")
76
+ pickle.dump(outlist, myout)
67
77
  myout.close()
68
-
78
+
69
79
  def get_number_of_weights(self):
70
- totnchan=0
80
+ totnchan = 0
71
81
  for i in range(self.nscale):
72
- totnchan=totnchan+self.chanlist[i]*self.chanlist[i+1]
73
- return self.npar*12*self.out_nside**2*self.chanlist[self.nscale] \
74
- +totnchan*self.KERNELSZ*self.KERNELSZ+self.KERNELSZ*self.KERNELSZ*self.n_chan_in*self.chanlist[0]
82
+ totnchan = totnchan + self.chanlist[i] * self.chanlist[i + 1]
83
+ return (
84
+ self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]
85
+ + totnchan * self.KERNELSZ * self.KERNELSZ
86
+ + self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
87
+ )
88
+
89
+ def set_weights(self, x):
90
+ self.x = x
75
91
 
76
- def set_weights(self,x):
77
- self.x=x
78
-
79
92
  def get_weights(self):
80
93
  return self.x
81
-
82
- def eval(self,im,indices=None,weights=None):
83
94
 
84
- x=self.x
85
- ww=self.scat_operator.backend.bk_reshape(x[0:self.KERNELSZ*self.KERNELSZ*self.n_chan_in*self.chanlist[0]],
86
- [self.KERNELSZ*self.KERNELSZ,self.n_chan_in,self.chanlist[0]])
87
- nn=self.KERNELSZ*self.KERNELSZ*self.n_chan_in*self.chanlist[0]
95
+ def eval(self, im, indices=None, weights=None):
96
+
97
+ x = self.x
98
+ ww = self.scat_operator.backend.bk_reshape(
99
+ x[0 : self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]],
100
+ [self.KERNELSZ * self.KERNELSZ, self.n_chan_in, self.chanlist[0]],
101
+ )
102
+ nn = self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
103
+
104
+ im = self.scat_operator.healpix_layer(im, ww)
105
+ im = self.scat_operator.backend.bk_relu(im)
88
106
 
89
- im=self.scat_operator.healpix_layer(im,ww)
90
- im=self.scat_operator.backend.bk_relu(im)
91
-
92
107
  for k in range(self.nscale):
93
- ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]],
94
- [self.KERNELSZ*self.KERNELSZ,self.chanlist[k],self.chanlist[k+1]])
95
- nn=nn+self.KERNELSZ*self.KERNELSZ*self.chanlist[k]*self.chanlist[k+1]
108
+ ww = self.scat_operator.backend.bk_reshape(
109
+ x[
110
+ nn : nn
111
+ + self.KERNELSZ
112
+ * self.KERNELSZ
113
+ * self.chanlist[k]
114
+ * self.chanlist[k + 1]
115
+ ],
116
+ [self.KERNELSZ * self.KERNELSZ, self.chanlist[k], self.chanlist[k + 1]],
117
+ )
118
+ nn = (
119
+ nn
120
+ + self.KERNELSZ
121
+ * self.KERNELSZ
122
+ * self.chanlist[k]
123
+ * self.chanlist[k + 1]
124
+ )
96
125
  if indices is None:
97
- im=self.scat_operator.healpix_layer(im,ww)
126
+ im = self.scat_operator.healpix_layer(im, ww)
98
127
  else:
99
- im=self.scat_operator.healpix_layer(im,ww,indices=indices[k],weights=weights[k])
100
- im=self.scat_operator.backend.bk_relu(im)
101
- im=self.scat_operator.ud_grade_2(im,axis=0)
102
-
103
-
104
- ww=self.scat_operator.backend.bk_reshape(x[nn:nn+self.npar*12*self.out_nside**2*self.chanlist[self.nscale]], \
105
- [12*self.out_nside**2*self.chanlist[self.nscale],self.npar])
106
-
107
- im=self.scat_operator.backend.bk_matmul(self.scat_operator.backend.bk_reshape(im,[1,12*self.out_nside**2*self.chanlist[self.nscale]]),ww)
108
- im=self.scat_operator.backend.bk_reshape(im,[self.npar])
109
- im=self.scat_operator.backend.bk_relu(im)
110
-
128
+ im = self.scat_operator.healpix_layer(
129
+ im, ww, indices=indices[k], weights=weights[k]
130
+ )
131
+ im = self.scat_operator.backend.bk_relu(im)
132
+ im = self.scat_operator.ud_grade_2(im, axis=0)
133
+
134
+ ww = self.scat_operator.backend.bk_reshape(
135
+ x[
136
+ nn : nn
137
+ + self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]
138
+ ],
139
+ [12 * self.out_nside**2 * self.chanlist[self.nscale], self.npar],
140
+ )
141
+
142
+ im = self.scat_operator.backend.bk_matmul(
143
+ self.scat_operator.backend.bk_reshape(
144
+ im, [1, 12 * self.out_nside**2 * self.chanlist[self.nscale]]
145
+ ),
146
+ ww,
147
+ )
148
+ im = self.scat_operator.backend.bk_reshape(im, [self.npar])
149
+ im = self.scat_operator.backend.bk_relu(im)
150
+
111
151
  return im
112
-
foscat/CircSpline.py CHANGED
@@ -1,6 +1,6 @@
1
-
2
1
  import math
3
2
 
3
+
4
4
  class CircSpline:
5
5
  def __init__(self, nodes, degree=3):
6
6
  """
@@ -14,7 +14,11 @@ class CircSpline:
14
14
  """
15
15
  Compute normalization factor for the ith element.
16
16
  """
17
- return pow(-1, i) * (self.degree + 1) / (math.factorial(self.degree + 1 - i) * math.factorial(i))
17
+ return (
18
+ pow(-1, i)
19
+ * (self.degree + 1)
20
+ / (math.factorial(self.degree + 1 - i) * math.factorial(i))
21
+ )
18
22
 
19
23
  def yplus(self, x):
20
24
  """
@@ -48,4 +52,3 @@ class CircSpline:
48
52
  tmp = 0.0
49
53
  y[self.nodes - 1 - i] += tmp
50
54
  return y
51
-