foscat 2025.6.1__py3-none-any.whl → 2025.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/HealSpline.py ADDED
@@ -0,0 +1,211 @@
1
+ from scipy.interpolate import interp1d
2
+ import foscat.CircSpline as sc
3
+ import foscat.Spline1D as sc1d
4
+ import numpy as np
5
+ import healpy as hp
6
+
7
class heal_spline:
    """Spline interpolation between scattered samples and HEALPix pixels.

    Each sample at colatitude/longitude (th, ph) is written as a weighted
    sum of 16 pixel coefficients: 4 rings selected by a ``foscat.Spline1D``
    spline across rings, times 4 pixels per ring selected by a periodic
    ``foscat.CircSpline`` spline along the ring (both built with parameter
    3, presumably cubic).  ``Fit`` solves for the coefficients with a
    conjugate gradient and ``Transform`` evaluates them at new positions.

    Fitted coefficients are also stored in ``spline_tree``, a dict mapping a
    coarse cell ``pixel // scale`` (``scale = (nside // nside_store)**2``)
    to a dense array holding that cell's ``scale`` coefficients.
    """

    def __init__(self, level):
        """Build the ring geometry for ``nside = 2**level``.

        Parameters
        ----------
        level : int
            Resolution level. The grid uses ``nside = 2**level`` and the
            coefficient tree is keyed at ``nside_store = 2**(level // 2)``.
        """
        self.spline_tree = {}
        self._init_geometry(level)

    def _init_geometry(self, level):
        """Compute the ring tables and latitude interpolators for ``nside = 2**level``."""
        nside = 2**level
        self.level = level
        self.nside = nside
        self.nside_store = 2**(level // 2)

        # idx_th[r] is the RING-ordered index of the first pixel of ring r
        # (4*nside-1 rings). idx_th[-1] = 12*nside**2 closes the last ring, so
        # idx_th[r+1] - idx_th[r] is the number of pixels on ring r.
        idx_th = np.zeros([4 * nside], dtype='int')
        n = 0
        d = 0
        # Rings 0 .. nside-1: sizes 4, 8, ..., 4*nside.
        for k in range(nside):
            d += 4
            idx_th[k] = n
            n += d
        # Rings nside .. 3*nside-2: 4*nside pixels each.
        for k in range(2 * nside - 1):
            idx_th[k + nside] = n
            n += d
        # Rings 3*nside-1 .. 4*nside-2: sizes 4*nside, ..., 8, 4.
        for k in range(nside):
            idx_th[k + 3 * nside - 1] = n
            n += d
            d -= 4
        idx_th[-1] = 12 * nside**2

        # Colatitude and longitude of the first pixel of every ring (RING order).
        th0_val, ph0_val = hp.pix2ang(nside, idx_th[:-1])
        self.th0_val = th0_val
        self.ph0_val = ph0_val
        self.idx_th = idx_th

        # Spline across the 4*nside-1 rings (latitude direction).
        self.spline_lat = sc1d.Spline1D(4 * nside - 1, 3)

        # Map colatitude to a normalised ring coordinate:
        #   - 0 at the north pole,
        #   - k/(4*nside) at the mid-point between rings k-1 and k,
        #   - (4*nside-1)/(4*nside) at the south pole.
        self.f_interp_th = interp1d(
            np.concatenate([[0], (th0_val[:-1] + th0_val[1:]) / 2, [np.pi]], 0),
            np.arange(4 * nside) / (4 * nside),
            kind='cubic',
            fill_value='extrapolate',
        )

    def ang2weigths(self, th, ph, threshold=1E-2, nest=True):
        """Compute the spline interpolation weights of points (th, ph).

        Parameters
        ----------
        th, ph : array_like
            Colatitude and longitude of the points, in radians. Both are
            flattened.
        threshold : float
            Pixels whose total weight summed over all points is below this
            value have their weights zeroed before per-point normalisation.
        nest : bool
            If True, ``heal_idx`` is returned in NESTED ordering, otherwise in
            RING ordering.

        Returns
        -------
        www : ndarray, shape (16, M)
            Weights of the 16 contributing pixels of each of the M points.
            Each column sums to 1, or is all zero if every weight was dropped.
        all_idx : ndarray of int, shape (16, M)
            Position in ``heal_idx`` of each contributing pixel.
        heal_idx : ndarray of int
            Unique HEALPix pixels touched by the points. This array is also
            stored as ``self.cell_ids``.
        """
        th0 = self.f_interp_th(th).flatten()
        # Flatten ph too, so it is indexed consistently with th0.
        ph = np.asarray(ph).ravel()
        npts = th0.shape[0]

        # For each k in 0..3, idx_lat[k] / w_th[k] give one contributing
        # ring per point and its latitude weight.
        idx_lat, w_th = self.spline_lat.eval(th0)

        www = np.zeros([4, 4, npts])
        all_idx = np.zeros([4, 4, npts], dtype='int')

        for iring in np.unique(idx_lat):
            # Periodic spline along this ring, sized by its pixel count.
            spline_table = sc.CircSpline(self.idx_th[iring + 1] - self.idx_th[iring], 3)
            for k in range(4):
                iii = np.where(idx_lat[k] == iring)[0]
                if iii.size == 0:
                    continue
                idx, w = spline_table.eval((ph[iii] - self.ph0_val[iring]) / (2 * np.pi))
                # Shift ring-local pixel indices to global RING indices.
                idx = idx + self.idx_th[iring]
                for m in range(4):
                    www[k, m, iii] = w[m] * w_th[k, iii]
                    all_idx[k, m, iii] = idx[m]

        www = www.reshape(16, npts)
        all_idx = all_idx.reshape(16, npts)

        heal_idx, inv_idx = np.unique(all_idx, return_inverse=True)
        # Depending on the NumPy version the inverse is flat (1.x) or shaped
        # like the input (2.x). Reshape explicitly so it always matches www.
        all_idx = inv_idx.reshape(all_idx.shape)
        if nest:
            heal_idx = hp.ring2nest(self.nside, heal_idx)
        self.cell_ids = heal_idx

        # Total weight received by each pixel; drop poorly constrained pixels.
        hit = np.bincount(all_idx.ravel(), weights=www.ravel())
        www[hit[all_idx] < threshold] = 0.0

        # Normalise each point's weights to sum to 1. A point whose weights
        # were all dropped keeps zero weights instead of becoming NaN.
        norm = np.sum(www, 0)[None, :]
        www = np.divide(www, norm, out=np.zeros_like(www), where=norm > 0)
        return www, all_idx, heal_idx

    def P(self, x, www, all_idx):
        """Forward operator: interpolate pixel coefficients ``x`` at the samples.

        ``www`` and ``all_idx`` have shape (K, M). Returns an array of shape (M,).
        """
        return np.sum(www * x[all_idx], 0)

    def PT(self, y, www, all_idx, hit):
        """Scaled adjoint: spread the sample values ``y`` (shape (M,)) onto the pixels.

        The accumulated values are multiplied by ``hit``, a per-pixel factor
        (the inverse summed weight, as built in
        ``conjugate_gradient_normal_equation``). Returns a 1D array of length
        ``all_idx.max() + 1``.
        """
        value = np.bincount(all_idx.flatten(), weights=(www * y[None, :]).flatten())
        return value * hit

    def conjugate_gradient_normal_equation(self, data, x0, www, all_idx, max_iter=100, tol=1e-8, verbose=True):
        """Solve the normal equations with an explicit conjugate gradient.

        Iterates on ``A x = b`` where ``A x = PT(P(x))`` and ``b = PT(data)``.
        ``PT`` is scaled by the inverse of the total weight of each pixel.

        Parameters
        ----------
        data : array_like
            Observed sample values, one per column of ``www``. Flattened.
        x0 : array_like
            Initial guess, one value per pixel referenced by ``all_idx``.
        www, all_idx : ndarray, shape (K, M)
            Interpolation weights and pixel positions, as returned by
            ``ang2weigths``.
        max_iter : int
            Maximum number of CG iterations.
        tol : float
            Stop when the absolute residual norm ``||b - A x||`` is below this.
        verbose : bool
            If True, print the residual every 50 iterations and on convergence.

        Returns
        -------
        x : ndarray of float
            Estimated pixel coefficients.
        """
        data = np.asarray(data, dtype=float).ravel()
        x = np.array(x0, dtype=float, copy=True)

        # Inverse total weight per pixel; pixels with no weight keep 0.
        hit = np.bincount(all_idx.flatten(), weights=www.flatten())
        hit[hit > 0] = 1 / hit[hit > 0]

        # NOTE(review): A = diag(hit) P^T P is not symmetric in general, so
        # the usual CG convergence guarantees do not strictly apply.
        def apply_A(v):
            return self.PT(self.P(v, www, all_idx), www, all_idx, hit)

        b = self.PT(data, www, all_idx, hit)
        r = b - apply_A(x)
        p = r.copy()
        rs_old = np.dot(r, r)

        # Already solved (e.g. all-zero data): avoid 0/0 in the first step.
        if np.sqrt(rs_old) < tol:
            return x

        for i in range(max_iter):
            Ap = apply_A(p)
            pAp = np.dot(p, Ap)
            if pAp == 0:
                # No curvature along p: a step would divide by zero.
                break

            alpha = rs_old / pAp
            x += alpha * p
            r -= alpha * Ap

            rs_new = np.dot(r, r)

            if verbose and i % 50 == 0:
                print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e}")

            if np.sqrt(rs_new) < tol:
                if verbose:
                    print(f"Converged. Iter {i:03d}: residual = {np.sqrt(rs_new):.3e}")
                break

            p = r + (rs_new / rs_old) * p
            rs_old = rs_new

        return x

    def _update_spline_tree(self, heal_idx, spline):
        """Scatter coefficients into ``spline_tree``, grouped by coarse cell ``heal_idx // scale``."""
        heal_idx = np.asarray(heal_idx)
        spline = np.asarray(spline)
        scale = (self.nside // self.nside_store)**2
        cells, inv = np.unique(heal_idx // scale, return_inverse=True)
        for k, cell in enumerate(cells):
            sel = inv == k
            spl = np.zeros([scale])
            spl[heal_idx[sel] - scale * cell] = spline[sel]
            self.spline_tree[cell] = spl

    def Fit(self, X, th, ph, threshold=1E-2, nest=True, max_iter=100, tol=1e-8):
        """Fit pixel coefficients to the samples ``X`` observed at (th, ph).

        Sets ``self.heal_idx`` and ``self.spline``, and (over)writes the
        ``spline_tree`` entries of every coarse cell touched by this fit.
        Entries from earlier fits for other cells are kept.
        """
        www, all_idx, hidx = self.ang2weigths(th, ph, threshold=threshold, nest=nest)

        self.heal_idx = hidx
        self.spline = self.conjugate_gradient_normal_equation(
            X,
            np.zeros(hidx.shape[0]),
            www,
            all_idx,
            max_iter=max_iter,
            tol=tol,
        )
        self._update_spline_tree(hidx, self.spline)

    def SetParam(self, nside, spline, heal_idx):
        """Load previously fitted coefficients.

        Parameters
        ----------
        nside : int
            Resolution of the coefficients. Ring tables are rebuilt if it
            differs from the current one.
        spline : array_like
            Coefficient value for each pixel in ``heal_idx``.
        heal_idx : array_like of int
            HEALPix pixel indices of the coefficients.
        """
        heal_idx = np.asarray(heal_idx)
        if getattr(self, 'nside', None) != nside:
            # ang2weigths relies on nside-dependent ring tables: rebuild them.
            self._init_geometry(int(np.log2(nside)))

        self.heal_idx = heal_idx
        self.nside = nside
        self.spline = spline
        self.level = int(np.log2(nside))
        self.nside_store = 2**(int(self.level // 2))

        self.spline_tree = {}
        self._update_spline_tree(heal_idx, spline)

    def Transform(self, th, ph, threshold=1E-2, nest=True):
        """Evaluate the fitted spline at positions (th, ph).

        Pixels whose coarse cell is missing from ``spline_tree`` contribute 0.
        Returns one value per point.
        """
        www, all_idx, hidx = self.ang2weigths(th, ph, threshold=threshold, nest=nest)

        scale = (self.nside // self.nside_store)**2
        x = np.zeros([hidx.shape[0]])
        cells, inv = np.unique(hidx // scale, return_inverse=True)
        for k, cell in enumerate(cells):
            spl = self.spline_tree.get(cell)
            if spl is None:
                continue
            sel = inv == k
            x[sel] = spl[hidx[sel] - scale * cell]
        return self.P(x, www, all_idx)

    def FitTransform(self, X, th, ph, threshold=1E-2, nest=True):
        """Fit ``X`` at (th, ph), then evaluate the fit at the fitted pixel centres.

        ``threshold`` and ``nest`` are forwarded to both ``Fit`` and
        ``Transform``. ``pix2ang`` uses the same ordering as ``heal_idx``.
        """
        self.Fit(X, th, ph, threshold=threshold, nest=nest)

        t, p = hp.pix2ang(self.nside, self.heal_idx, nest=nest)

        return self.Transform(t, p, threshold=threshold, nest=nest)
foscat/heal_NN.py CHANGED
@@ -7,18 +7,19 @@ import foscat.scat_cov as sc
7
7
  class CNN:
8
8
 
9
9
  def __init__(
10
- self,
11
- nparam=1,
12
- KERNELSZ=3,
13
- NORIENT=4,
14
- chanlist=[],
15
- in_nside=1,
16
- n_chan_in=1,
17
- SEED=1234,
18
- all_type='float32',
19
- filename=None,
20
- scat_operator=None,
21
- BACKEND='tensorflow'
10
+ self,
11
+ nparam=1,
12
+ KERNELSZ=3,
13
+ NORIENT=4,
14
+ chanlist=[],
15
+ in_nside=1,
16
+ n_chan_in=1,
17
+ SEED=1234,
18
+ add_undersample_data=False,
19
+ all_type='float32',
20
+ filename=None,
21
+ scat_operator=None,
22
+ BACKEND='tensorflow'
22
23
  ):
23
24
 
24
25
  if filename is not None:
@@ -38,6 +39,7 @@ class CNN:
38
39
  self.x = self.scat_operator.backend.bk_cast(outlist[6])
39
40
  self.out_nside = self.in_nside // (2**(self.nscale+1))
40
41
  else:
42
+ self.add_undersample_data=add_undersample_data
41
43
  self.nscale = len(chanlist)-1
42
44
  self.npar = nparam
43
45
  self.n_chan_in = n_chan_in
@@ -89,9 +91,9 @@ class CNN:
89
91
  def get_number_of_weights(self):
90
92
  totnchan = 0
91
93
  for i in range(self.nscale):
92
- totnchan = totnchan + self.chanlist[i] * self.chanlist[i + 1]
94
+ totnchan = totnchan + (self.chanlist[i]+int(self.add_undersample_data)* self.n_chan_in) * self.chanlist[i + 1]
93
95
  return (
94
- self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]*self.NORIENT
96
+ self.npar * 12 * self.out_nside**2 * (self.chanlist[self.nscale]+int(self.add_undersample_data)* self.n_chan_in)*self.NORIENT
95
97
  + totnchan * self.KERNELSZ * (self.KERNELSZ//2+1)*self.NORIENT*self.NORIENT
96
98
  + self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]*self.NORIENT
97
99
  )
@@ -190,7 +192,7 @@ class CNN:
190
192
  m=np.mean(m,0)
191
193
  return self.scat_operator.backend.bk_cast(m[None,None,...])
192
194
 
193
- def eval(self, im,
195
+ def eval(self, in_im,
194
196
  indices=None,
195
197
  weights=None,
196
198
  out_map=False,
@@ -204,14 +206,12 @@ class CNN:
204
206
  )
205
207
  nn = self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]*self.NORIENT
206
208
 
207
- im = self.scat_operator.healpix_layer(im[:,:,None,:], ww)
209
+ im = self.scat_operator.healpix_layer(in_im[:,:,None,:], ww)
208
210
 
209
211
  if first_layer_rot is not None:
210
212
  im = self.backend.bk_reshape(im,[im.shape[0],im.shape[1],self.NORIENT,1,im.shape[3]])
211
213
  im = self.backend.bk_reduce_sum(im*first_layer_rot,2)
212
214
 
213
- if out_map:
214
- return im
215
215
 
216
216
  if activation=='relu':
217
217
  im = self.backend.bk_relu(im)
@@ -220,6 +220,16 @@ class CNN:
220
220
 
221
221
  im = self.backend.bk_reduce_sum(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],self.NORIENT,im.shape[3]//4,4]),4)
222
222
 
223
+ if self.add_undersample_data:
224
+ l_im=self.backend.bk_repeat(self.backend.bk_reshape(in_im,[in_im.shape[0],in_im.shape[1],1,in_im.shape[2]]),self.NORIENT,2)
225
+ l_im=self.backend.bk_reduce_sum(
226
+ self.backend.bk_reshape(l_im,[in_im.shape[0],in_im.shape[1],self.NORIENT,l_im.shape[3]//4,4]), 4)
227
+ im=self.backend.bk_concat([im,l_im],1)
228
+
229
+
230
+ if out_map:
231
+ return im
232
+
223
233
  for k in range(self.nscale):
224
234
  ww = self.scat_operator.backend.bk_reshape(
225
235
  x[
@@ -227,17 +237,21 @@ class CNN:
227
237
  + self.KERNELSZ
228
238
  * (self.KERNELSZ//2+1)
229
239
  * self.NORIENT*self.NORIENT
230
- * self.chanlist[k]
240
+ * (self.chanlist[k]+int(self.add_undersample_data)* self.n_chan_in)
231
241
  * self.chanlist[k + 1]
232
242
  ],
233
- [self.chanlist[k], self.NORIENT, self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[k + 1], self.NORIENT],
243
+ [self.chanlist[k]+int(self.add_undersample_data)* self.n_chan_in,
244
+ self.NORIENT,
245
+ self.KERNELSZ * (self.KERNELSZ//2+1),
246
+ self.chanlist[k + 1],
247
+ self.NORIENT],
234
248
  )
235
249
  nn = (
236
250
  nn
237
251
  + self.KERNELSZ
238
252
  * (self.KERNELSZ//2+1)
239
253
  * self.NORIENT*self.NORIENT
240
- * self.chanlist[k]
254
+ * (self.chanlist[k]+int(self.add_undersample_data)* self.n_chan_in)
241
255
  * self.chanlist[k + 1]
242
256
  )
243
257
  if indices is None:
@@ -253,12 +267,17 @@ class CNN:
253
267
  im = self.backend.bk_abs(im)
254
268
  im = self.backend.bk_reduce_sum(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],self.NORIENT,im.shape[3]//4,4]),4)
255
269
 
270
+ if self.add_undersample_data:
271
+ l_im=self.backend.bk_reduce_sum(
272
+ self.backend.bk_reshape(l_im,[l_im.shape[0],l_im.shape[1],self.NORIENT,l_im.shape[3]//4,4]), 4)
273
+ im=self.backend.bk_concat([im,l_im],1)
274
+
256
275
  ww = self.scat_operator.backend.bk_reshape(
257
276
  x[
258
277
  nn : nn
259
- + self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]*self.NORIENT
278
+ + self.npar * 12 * self.out_nside**2 * (self.chanlist[self.nscale]+int(self.add_undersample_data)* self.n_chan_in)*self.NORIENT
260
279
  ],
261
- [12 * self.out_nside**2 * self.chanlist[self.nscale]*self.NORIENT, self.npar],
280
+ [12 * self.out_nside**2 * (self.chanlist[self.nscale]+int(self.add_undersample_data)* self.n_chan_in)*self.NORIENT, self.npar],
262
281
  )
263
282
 
264
283
  im = self.scat_operator.backend.bk_matmul(
@@ -429,4 +448,4 @@ class GCNN:
429
448
  im = self.backend.bk_reduce_mean(im[:,:,:,None]*ww,[1,2])
430
449
  #im = self.scat_operator.backend.bk_relu(im)
431
450
 
432
- return im
451
+ return im