foscat 2025.6.1__py3-none-any.whl → 2025.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/heal_NN.py CHANGED
@@ -7,18 +7,19 @@ import foscat.scat_cov as sc
7
7
  class CNN:
8
8
 
9
9
  def __init__(
10
- self,
11
- nparam=1,
12
- KERNELSZ=3,
13
- NORIENT=4,
14
- chanlist=[],
15
- in_nside=1,
16
- n_chan_in=1,
17
- SEED=1234,
18
- all_type='float32',
19
- filename=None,
20
- scat_operator=None,
21
- BACKEND='tensorflow'
10
+ self,
11
+ nparam=1,
12
+ KERNELSZ=3,
13
+ NORIENT=4,
14
+ chanlist=[],
15
+ in_nside=1,
16
+ n_chan_in=1,
17
+ SEED=1234,
18
+ add_undersample_data=False,
19
+ all_type='float32',
20
+ filename=None,
21
+ scat_operator=None,
22
+ BACKEND='tensorflow'
22
23
  ):
23
24
 
24
25
  if filename is not None:
@@ -38,6 +39,7 @@ class CNN:
38
39
  self.x = self.scat_operator.backend.bk_cast(outlist[6])
39
40
  self.out_nside = self.in_nside // (2**(self.nscale+1))
40
41
  else:
42
+ self.add_undersample_data=add_undersample_data
41
43
  self.nscale = len(chanlist)-1
42
44
  self.npar = nparam
43
45
  self.n_chan_in = n_chan_in
@@ -89,9 +91,9 @@ class CNN:
89
91
  def get_number_of_weights(self):
90
92
  totnchan = 0
91
93
  for i in range(self.nscale):
92
- totnchan = totnchan + self.chanlist[i] * self.chanlist[i + 1]
94
+ totnchan = totnchan + (self.chanlist[i]+int(self.add_undersample_data)* self.n_chan_in) * self.chanlist[i + 1]
93
95
  return (
94
- self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]*self.NORIENT
96
+ self.npar * 12 * self.out_nside**2 * (self.chanlist[self.nscale]+int(self.add_undersample_data)* self.n_chan_in)*self.NORIENT
95
97
  + totnchan * self.KERNELSZ * (self.KERNELSZ//2+1)*self.NORIENT*self.NORIENT
96
98
  + self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]*self.NORIENT
97
99
  )
@@ -190,7 +192,7 @@ class CNN:
190
192
  m=np.mean(m,0)
191
193
  return self.scat_operator.backend.bk_cast(m[None,None,...])
192
194
 
193
- def eval(self, im,
195
+ def eval(self, in_im,
194
196
  indices=None,
195
197
  weights=None,
196
198
  out_map=False,
@@ -204,14 +206,12 @@ class CNN:
204
206
  )
205
207
  nn = self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]*self.NORIENT
206
208
 
207
- im = self.scat_operator.healpix_layer(im[:,:,None,:], ww)
209
+ im = self.scat_operator.healpix_layer(in_im[:,:,None,:], ww)
208
210
 
209
211
  if first_layer_rot is not None:
210
212
  im = self.backend.bk_reshape(im,[im.shape[0],im.shape[1],self.NORIENT,1,im.shape[3]])
211
213
  im = self.backend.bk_reduce_sum(im*first_layer_rot,2)
212
214
 
213
- if out_map:
214
- return im
215
215
 
216
216
  if activation=='relu':
217
217
  im = self.backend.bk_relu(im)
@@ -220,6 +220,16 @@ class CNN:
220
220
 
221
221
  im = self.backend.bk_reduce_sum(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],self.NORIENT,im.shape[3]//4,4]),4)
222
222
 
223
+ if self.add_undersample_data:
224
+ l_im=self.backend.bk_repeat(self.backend.bk_reshape(in_im,[in_im.shape[0],in_im.shape[1],1,in_im.shape[2]]),self.NORIENT,2)
225
+ l_im=self.backend.bk_reduce_sum(
226
+ self.backend.bk_reshape(l_im,[in_im.shape[0],in_im.shape[1],self.NORIENT,l_im.shape[3]//4,4]), 4)
227
+ im=self.backend.bk_concat([im,l_im],1)
228
+
229
+
230
+ if out_map:
231
+ return im
232
+
223
233
  for k in range(self.nscale):
224
234
  ww = self.scat_operator.backend.bk_reshape(
225
235
  x[
@@ -227,17 +237,21 @@ class CNN:
227
237
  + self.KERNELSZ
228
238
  * (self.KERNELSZ//2+1)
229
239
  * self.NORIENT*self.NORIENT
230
- * self.chanlist[k]
240
+ * (self.chanlist[k]+int(self.add_undersample_data)* self.n_chan_in)
231
241
  * self.chanlist[k + 1]
232
242
  ],
233
- [self.chanlist[k], self.NORIENT, self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[k + 1], self.NORIENT],
243
+ [self.chanlist[k]+int(self.add_undersample_data)* self.n_chan_in,
244
+ self.NORIENT,
245
+ self.KERNELSZ * (self.KERNELSZ//2+1),
246
+ self.chanlist[k + 1],
247
+ self.NORIENT],
234
248
  )
235
249
  nn = (
236
250
  nn
237
251
  + self.KERNELSZ
238
252
  * (self.KERNELSZ//2+1)
239
253
  * self.NORIENT*self.NORIENT
240
- * self.chanlist[k]
254
+ * (self.chanlist[k]+int(self.add_undersample_data)* self.n_chan_in)
241
255
  * self.chanlist[k + 1]
242
256
  )
243
257
  if indices is None:
@@ -253,12 +267,17 @@ class CNN:
253
267
  im = self.backend.bk_abs(im)
254
268
  im = self.backend.bk_reduce_sum(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],self.NORIENT,im.shape[3]//4,4]),4)
255
269
 
270
+ if self.add_undersample_data:
271
+ l_im=self.backend.bk_reduce_sum(
272
+ self.backend.bk_reshape(l_im,[l_im.shape[0],l_im.shape[1],self.NORIENT,l_im.shape[3]//4,4]), 4)
273
+ im=self.backend.bk_concat([im,l_im],1)
274
+
256
275
  ww = self.scat_operator.backend.bk_reshape(
257
276
  x[
258
277
  nn : nn
259
- + self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]*self.NORIENT
278
+ + self.npar * 12 * self.out_nside**2 * (self.chanlist[self.nscale]+int(self.add_undersample_data)* self.n_chan_in)*self.NORIENT
260
279
  ],
261
- [12 * self.out_nside**2 * self.chanlist[self.nscale]*self.NORIENT, self.npar],
280
+ [12 * self.out_nside**2 * (self.chanlist[self.nscale]+int(self.add_undersample_data)* self.n_chan_in)*self.NORIENT, self.npar],
262
281
  )
263
282
 
264
283
  im = self.scat_operator.backend.bk_matmul(
@@ -429,4 +448,4 @@ class GCNN:
429
448
  im = self.backend.bk_reduce_mean(im[:,:,:,None]*ww,[1,2])
430
449
  #im = self.scat_operator.backend.bk_relu(im)
431
450
 
432
- return im
451
+ return im