foscat 2025.5.2__py3-none-any.whl → 2025.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/heal_NN.py ADDED
@@ -0,0 +1,451 @@
1
+ import pickle
2
+
3
+ import numpy as np
4
+
5
+ import foscat.scat_cov as sc
6
+
7
class CNN:
    """HEALPix convolutional network mapping input maps to ``nparam`` values.

    Architecture, as implemented in :meth:`eval`:

    * a first convolution from ``n_chan_in`` input channels to
      ``chanlist[0]`` channels x ``NORIENT`` orientations;
    * ``len(chanlist) - 1`` further convolutions from ``chanlist[k]`` to
      ``chanlist[k + 1]`` channels (``NORIENT`` orientations in and out);
    * after every convolution, an activation and then a sum over groups of 4
      consecutive pixels. The pixel count therefore drops by 4 per layer, and
      ``out_nside = in_nside // 2**(nscale + 1)``;
    * a final dense layer from the flattened features to ``nparam`` outputs.

    All trainable weights live in one flat vector ``self.x``. Its layout is
    described by :meth:`get_number_of_weights`.
    """

    def __init__(
        self,
        nparam=1,
        KERNELSZ=3,
        NORIENT=4,
        chanlist=None,
        in_nside=1,
        n_chan_in=1,
        SEED=1234,
        add_undersample_data=False,
        all_type='float32',
        filename=None,
        scat_operator=None,
        BACKEND='tensorflow'
    ):
        """Build a new network, or restore one saved with :meth:`save`.

        Args:
            nparam: Number of outputs of the final dense layer.
            KERNELSZ: Kernel size forwarded to ``sc.funct`` when no
                ``scat_operator`` is given.
            NORIENT: Number of orientations forwarded to ``sc.funct``.
            chanlist: Channel counts per layer. ``len(chanlist) - 1``
                downsampling layers follow the first layer.
            in_nside: nside of the input maps.
            n_chan_in: Number of input channels.
            SEED: Seed passed to ``np.random.seed``. This reseeds numpy's
                global RNG before the initial random weights are drawn.
            add_undersample_data: If True, a downsampled copy of the input is
                concatenated to the features after every layer.
            all_type: dtype name forwarded to ``sc.funct``.
            filename: If given, hyper-parameters and weights are loaded from
                ``"<filename>.pkl"``. All other arguments except ``BACKEND``
                are then ignored.
            scat_operator: Optional pre-built ``sc.funct`` operator to reuse.
            BACKEND: Stored as ``self.BACKEND``.
        """
        if filename is not None:
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # files from a trusted source.
            with open("%s.pkl" % (filename), "rb") as fin:
                outlist = pickle.load(fin)
            self.scat_operator = sc.funct(KERNELSZ=outlist[3],
                                          NORIENT=outlist[9],
                                          all_type=outlist[7])
            self.KERNELSZ = self.scat_operator.KERNELSZ
            self.all_type = self.scat_operator.all_type
            self.npar = outlist[2]
            self.nscale = outlist[5]
            self.chanlist = outlist[0]
            self.in_nside = outlist[4]
            self.nbatch = outlist[1]
            self.n_chan_in = outlist[8]
            self.NORIENT = outlist[9]
            # Older files have only 10 entries and predate this flag.
            self.add_undersample_data = outlist[10] if len(outlist) > 10 else False
            self.x = self.scat_operator.backend.bk_cast(outlist[6])
            self.out_nside = self.in_nside // (2**(self.nscale + 1))
        else:
            if chanlist is None:
                chanlist = []
            self.add_undersample_data = add_undersample_data
            self.nscale = len(chanlist) - 1
            self.npar = nparam
            self.n_chan_in = n_chan_in
            # This class never reads nbatch; it is kept so save() has a value
            # for the pickle slot that the filename branch restores.
            self.nbatch = None
            if scat_operator is None:
                self.scat_operator = sc.funct(
                    KERNELSZ=KERNELSZ,
                    NORIENT=NORIENT,
                    all_type=all_type)
            else:
                self.scat_operator = scat_operator

            self.chanlist = chanlist
            self.KERNELSZ = self.scat_operator.KERNELSZ
            self.NORIENT = self.scat_operator.NORIENT
            self.all_type = self.scat_operator.all_type
            self.in_nside = in_nside
            self.out_nside = self.in_nside // (2**(self.nscale + 1))
            np.random.seed(SEED)
            self.x = self.scat_operator.backend.bk_cast(
                np.random.rand(self.get_number_of_weights())
            )
        # Set for both branches: eval() and the helpers rely on it.
        self.backend = self.scat_operator.backend
        self.mpi_size = self.scat_operator.mpi_size
        self.mpi_rank = self.scat_operator.mpi_rank
        self.BACKEND = BACKEND
        self.gpupos = self.scat_operator.gpupos
        self.ngpu = self.scat_operator.ngpu
        self.gpulist = self.scat_operator.gpulist

    def _weights_as_numpy(self):
        """Return the current weight vector as a numpy array."""
        w = self.get_weights()
        if hasattr(w, "detach"):
            # torch tensors must leave the autograd graph / device first.
            w = w.detach().cpu()
        if hasattr(w, "numpy"):
            return w.numpy()
        return np.asarray(w)

    def save(self, filename):
        """Pickle hyper-parameters and weights to ``"<filename>.pkl"``.

        The list layout matches what ``__init__(filename=...)`` reads.
        Index 10 (``add_undersample_data``) is appended last, so readers that
        only look at indices 0-9 still work.
        """
        outlist = [
            self.chanlist,
            self.nbatch,
            self.npar,
            self.KERNELSZ,
            self.in_nside,
            self.nscale,
            self._weights_as_numpy(),
            self.all_type,
            self.n_chan_in,
            self.NORIENT,
            self.add_undersample_data,
        ]
        with open("%s.pkl" % (filename), "wb") as fout:
            pickle.dump(outlist, fout)

    def _n_kernel_weights(self):
        """Number of weights per kernel: ``KERNELSZ * (KERNELSZ // 2 + 1)``."""
        return self.KERNELSZ * (self.KERNELSZ // 2 + 1)

    def _n_extra_chan(self):
        """Channels appended to each layer's input by ``add_undersample_data``."""
        return int(self.add_undersample_data) * self.n_chan_in

    def get_number_of_weights(self):
        """Return the length of the flat weight vector ``self.x``.

        Layout, in order:

        1. first layer: ``nk * n_chan_in * chanlist[0] * NORIENT``;
        2. layer ``k``: ``nk * NORIENT**2 * (chanlist[k] + extra) * chanlist[k+1]``;
        3. dense layer: ``npar * 12 * out_nside**2 * (chanlist[-1] + extra) * NORIENT``.

        Here ``nk = KERNELSZ * (KERNELSZ // 2 + 1)`` and ``extra`` is
        ``n_chan_in`` when ``add_undersample_data`` is set, else 0.
        """
        nk = self._n_kernel_weights()
        n_extra = self._n_extra_chan()
        totnchan = sum(
            (self.chanlist[i] + n_extra) * self.chanlist[i + 1]
            for i in range(self.nscale)
        )
        return (
            self.npar * 12 * self.out_nside**2
            * (self.chanlist[self.nscale] + n_extra) * self.NORIENT
            + totnchan * nk * self.NORIENT * self.NORIENT
            + nk * self.n_chan_in * self.chanlist[0] * self.NORIENT
        )

    def set_weights(self, x):
        """Replace the flat weight vector."""
        self.x = x

    def get_weights(self):
        """Return the flat weight vector."""
        return self.x

    def _wave_template(self):
        """Return the fixed per-kernel initial pattern of length ``nk``.

        The pattern is hand-set for ``KERNELSZ`` 3 and 5, and zero otherwise.
        """
        template = np.zeros(self._n_kernel_weights())
        if self.KERNELSZ == 3:
            template[:] = [-0.2, -0.5, -0.2, 0.2, 0.5, 0.2]
        elif self.KERNELSZ == 5:
            template[0:5] = [-0.1, -0.2, -0.5, -0.2, -0.1]
            template[10:15] = [0.1, 0.2, 0.5, 0.2, 0.1]
        return template

    def init_wave(self):
        """Reinitialise ``self.x`` with fixed kernels and random dense weights.

        Every convolution kernel gets the pattern from ``_wave_template``.
        The dense-layer weights are drawn uniformly in ``[-a/2, a/2)`` with
        ``a = 2 * sqrt(6 / fan_in)``. Here ``fan_in`` is the dense input size
        (including any ``add_undersample_data`` channels) times ``npar``.
        """
        nk = self._n_kernel_weights()
        n_extra = self._n_extra_chan()
        template = self._wave_template()

        a = 2 * np.sqrt(6 / (12 * self.out_nside**2
                             * (self.chanlist[self.nscale] + n_extra)
                             * self.NORIENT * self.npar))
        x = (np.random.rand(self.get_number_of_weights()) - 0.5) * a

        # First layer, same layout as the reshape in eval().
        w0 = np.zeros([self.n_chan_in, nk, self.chanlist[0], self.NORIENT])
        w0[:] = template[None, :, None, None]
        w0 = w0.flatten()
        x[0:w0.shape[0]] = w0
        nn = w0.shape[0]

        # Hidden layers. The input channel count must include the
        # undersampled-data channels so offsets match eval().
        for k in range(self.nscale):
            nin = self.chanlist[k] + n_extra
            ww = np.zeros([nin, self.NORIENT, nk, self.chanlist[k + 1], self.NORIENT])
            ww[:] = template[None, None, :, None, None]
            x[nn:nn + ww.size] = ww.flatten()
            nn += ww.size

        self.x = self.backend.bk_cast(x)

    def calc_matrix_first_layer(self, noise_map):
        """Estimate a per-pixel orientation-mixing matrix from a noise map.

        ``noise_map`` is convolved with ``scat_operator.convol``. The result
        must expose ``.cpu().numpy()`` (e.g. a torch tensor) and is assumed to
        be indexed ``(batch, channel, orientation, pixel)``; TODO confirm.
        From the batch-averaged absolute responses, a dominant fractional
        orientation ``o`` is computed per channel and pixel. The linear
        interpolation weights towards ``o`` are then rolled over every
        orientation shift and averaged over channels.

        Returns:
            Backend tensor of shape ``(1, 1, norient, norient, npix)``. It is
            shaped so it can be passed as ``first_layer_rot`` to :meth:`eval`.
        """
        im = self.scat_operator.convol(noise_map)
        mm = np.mean(abs(im.cpu().numpy()), 0)
        norient = mm.shape[1]
        phase = np.arange(norient) / norient * 2 * np.pi

        # Circular mean of the orientation responses -> dominant angle index.
        a = np.sum(mm * np.cos(phase)[None, :, None], 1)
        b = np.sum(mm * np.sin(phase)[None, :, None], 1)
        o = np.fmod(norient * np.arctan2(-b, a) / (2 * np.pi) + norient, norient)

        idx = np.arange(norient)
        alpha = o[:, None, :] - idx[None, :, None]
        beta = np.fmod(1 + o[:, None, :] - idx[None, :, None], norient)
        alpha = (1 - alpha) * (alpha < 1) * (alpha > 0) + beta * (beta < 1) * (beta > 0)

        # Size by the actual orientation count (previously hard-coded to 4).
        m = np.zeros([mm.shape[0], norient, norient, mm.shape[2]])
        for k in range(norient):
            m[:, k, :, :] = np.roll(alpha, k, 1)
        m = np.mean(m, 0)
        return self.backend.bk_cast(m[None, None, ...])

    def _activate(self, im, activation):
        """Apply 'relu' or 'abs'; any other value leaves ``im`` unchanged."""
        if activation == 'relu':
            return self.backend.bk_relu(im)
        if activation == 'abs':
            return self.backend.bk_abs(im)
        return im

    def _downgrade(self, im):
        """Sum groups of 4 consecutive pixels on the last axis of a 4-D map.

        The ``out_nside`` formula treats this as halving nside. Presumably it
        assumes HEALPix NESTED pixel ordering; confirm.
        """
        return self.backend.bk_reduce_sum(
            self.backend.bk_reshape(
                im, [im.shape[0], im.shape[1], self.NORIENT, im.shape[3] // 4, 4]
            ),
            4,
        )

    def eval(self, in_im,
             indices=None,
             weights=None,
             out_map=False,
             first_layer_rot=None,
             activation='relu'):
        """Run the network on ``in_im``.

        Args:
            in_im: Input maps. They are indexed as ``in_im[:, :, None, :]``,
                so presumably ``(batch, n_chan_in, npix)``; confirm.
            indices, weights: Optional per-layer sequences. Element ``k`` of
                each is forwarded to ``healpix_layer`` for hidden layer ``k``.
            out_map: If True, return the feature maps after the first layer.
            first_layer_rot: Optional matrix (see
                :meth:`calc_matrix_first_layer`) that mixes orientations of
                the first-layer output.
            activation: 'relu', 'abs', or anything else for no activation.

        Returns:
            ``(batch, nparam)`` output of the dense layer, or the first-layer
            maps when ``out_map`` is True.
        """
        bk = self.backend
        x = self.x
        nk = self._n_kernel_weights()
        n_extra = self._n_extra_chan()

        # First layer.
        nn = nk * self.n_chan_in * self.chanlist[0] * self.NORIENT
        ww = bk.bk_reshape(
            x[0:nn],
            [self.n_chan_in, 1, nk, self.chanlist[0], self.NORIENT],
        )
        im = self.scat_operator.healpix_layer(in_im[:, :, None, :], ww)

        if first_layer_rot is not None:
            im = bk.bk_reshape(im, [im.shape[0], im.shape[1], self.NORIENT, 1, im.shape[3]])
            im = bk.bk_reduce_sum(im * first_layer_rot, 2)

        im = self._downgrade(self._activate(im, activation))

        if self.add_undersample_data:
            # Input replicated over orientations and downgraded like the features.
            l_im = bk.bk_repeat(
                bk.bk_reshape(in_im, [in_im.shape[0], in_im.shape[1], 1, in_im.shape[2]]),
                self.NORIENT, 2)
            l_im = self._downgrade(l_im)
            im = bk.bk_concat([im, l_im], 1)

        if out_map:
            return im

        # Hidden layers.
        for k in range(self.nscale):
            nin = self.chanlist[k] + n_extra
            nw = nk * self.NORIENT * self.NORIENT * nin * self.chanlist[k + 1]
            ww = bk.bk_reshape(
                x[nn:nn + nw],
                [nin, self.NORIENT, nk, self.chanlist[k + 1], self.NORIENT],
            )
            nn += nw
            if indices is None:
                im = self.scat_operator.healpix_layer(im, ww)
            else:
                im = self.scat_operator.healpix_layer(
                    im, ww, indices=indices[k], weights=weights[k]
                )

            im = self._downgrade(self._activate(im, activation))

            if self.add_undersample_data:
                l_im = self._downgrade(l_im)
                im = bk.bk_concat([im, l_im], 1)

        # Dense layer over all remaining channels, orientations and pixels.
        nfeat = (12 * self.out_nside**2
                 * (self.chanlist[self.nscale] + n_extra) * self.NORIENT)
        ww = bk.bk_reshape(x[nn:nn + self.npar * nfeat], [nfeat, self.npar])
        return bk.bk_matmul(
            bk.bk_reshape(im, [im.shape[0], im.shape[1] * im.shape[2] * im.shape[3]]),
            ww,
        )
292
+
293
class GCNN:
    """Generative HEALPix network mapping ``nparam`` values to output maps.

    Architecture, as implemented in :meth:`eval`:

    * a dense layer from ``nparam`` inputs to ``chanlist[0]`` channels x
      ``NORIENT`` orientations at ``in_nside``;
    * ``len(chanlist) - 1`` layers. Each applies ReLU, repeats every pixel
      4 times along the pixel axis (nside doubles, so
      ``out_nside = in_nside * 2**nscale``), then convolves from
      ``chanlist[k]`` to ``chanlist[k + 1]`` channels;
    * a final weighting per channel and orientation, averaged over channels
      and orientations, producing ``out_chan`` maps.

    All trainable weights live in one flat vector ``self.x``. Its layout is
    described by :meth:`get_number_of_weights`.
    """

    def __init__(
        self,
        nparam=1,
        KERNELSZ=3,
        NORIENT=4,
        chanlist=None,
        in_nside=1,
        out_chan=1,
        SEED=1234,
        all_type='float32',
        filename=None,
        scat_operator=None,
        BACKEND='tensorflow'
    ):
        """Build a new generator, or restore one saved with :meth:`save`.

        Args:
            nparam: Size of the input vector.
            KERNELSZ: Kernel size forwarded to ``sc.funct`` when no
                ``scat_operator`` is given.
            NORIENT: Number of orientations forwarded to ``sc.funct``.
            chanlist: Channel counts per layer. ``len(chanlist) - 1``
                upsampling layers follow the dense layer.
            in_nside: nside of the dense-layer output.
            out_chan: Number of output maps.
            SEED: Seed passed to ``np.random.seed``. This reseeds numpy's
                global RNG before the initial random weights are drawn.
            all_type: dtype name forwarded to ``sc.funct``.
            filename: If given, hyper-parameters and weights are loaded from
                ``"<filename>.pkl"``. All other arguments except ``BACKEND``
                are then ignored.
            scat_operator: Optional pre-built ``sc.funct`` operator to reuse.
            BACKEND: Stored as ``self.BACKEND``.
        """
        if filename is not None:
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # files from a trusted source.
            with open("%s.pkl" % (filename), "rb") as fin:
                outlist = pickle.load(fin)
            self.scat_operator = sc.funct(KERNELSZ=outlist[3],
                                          NORIENT=outlist[8],
                                          all_type=outlist[7])
            self.KERNELSZ = self.scat_operator.KERNELSZ
            self.all_type = self.scat_operator.all_type
            self.npar = outlist[2]
            self.nscale = outlist[5]
            self.chanlist = outlist[0]
            self.in_nside = outlist[4]
            self.nbatch = outlist[1]
            self.NORIENT = outlist[8]
            self.out_chan = outlist[9]
            self.x = self.scat_operator.backend.bk_cast(outlist[6])
            # Each layer doubles nside, exactly as in the fresh-build branch.
            self.out_nside = self.in_nside * (2**self.nscale)
        else:
            if chanlist is None:
                chanlist = []
            self.nscale = len(chanlist) - 1
            self.npar = nparam
            # This class never reads nbatch; it is kept so save() has a value
            # for the pickle slot that the filename branch restores.
            self.nbatch = None

            if scat_operator is None:
                self.scat_operator = sc.funct(
                    KERNELSZ=KERNELSZ,
                    NORIENT=NORIENT,
                    all_type=all_type)
            else:
                self.scat_operator = scat_operator

            self.chanlist = chanlist
            self.KERNELSZ = self.scat_operator.KERNELSZ
            self.NORIENT = self.scat_operator.NORIENT
            self.all_type = self.scat_operator.all_type
            self.in_nside = in_nside
            self.out_nside = self.in_nside * (2**self.nscale)
            self.out_chan = out_chan
            np.random.seed(SEED)
            self.x = self.scat_operator.backend.bk_cast(
                np.random.rand(self.get_number_of_weights())
            )
        # Set for both branches: eval() relies on it.
        self.backend = self.scat_operator.backend
        self.mpi_size = self.scat_operator.mpi_size
        self.mpi_rank = self.scat_operator.mpi_rank
        self.BACKEND = BACKEND
        self.gpupos = self.scat_operator.gpupos
        self.ngpu = self.scat_operator.ngpu
        self.gpulist = self.scat_operator.gpulist

    def _weights_as_numpy(self):
        """Return the current weight vector as a numpy array."""
        w = self.get_weights()
        if hasattr(w, "detach"):
            # torch tensors must leave the autograd graph / device first.
            w = w.detach().cpu()
        if hasattr(w, "numpy"):
            return w.numpy()
        return np.asarray(w)

    def save(self, filename):
        """Pickle hyper-parameters and weights to ``"<filename>.pkl"``.

        The list layout matches what ``__init__(filename=...)`` reads.
        """
        outlist = [
            self.chanlist,
            self.nbatch,
            self.npar,
            self.KERNELSZ,
            self.in_nside,
            self.nscale,
            self._weights_as_numpy(),
            self.all_type,
            self.NORIENT,
            self.out_chan,
        ]
        with open("%s.pkl" % (filename), "wb") as fout:
            pickle.dump(outlist, fout)

    def _n_kernel_weights(self):
        """Number of weights per kernel: ``KERNELSZ * (KERNELSZ // 2 + 1)``."""
        return self.KERNELSZ * (self.KERNELSZ // 2 + 1)

    def get_number_of_weights(self):
        """Return the length of the flat weight vector ``self.x``.

        Layout, in order:

        1. dense layer: ``npar * 12 * in_nside**2 * chanlist[0] * NORIENT``;
        2. layer ``k``: ``nk * NORIENT**2 * chanlist[k] * chanlist[k+1]``;
        3. output weighting: ``chanlist[-1] * out_chan * NORIENT``.

        Here ``nk = KERNELSZ * (KERNELSZ // 2 + 1)``.
        """
        totnchan = sum(self.chanlist[i] * self.chanlist[i + 1] for i in range(self.nscale))
        return (
            self.npar * 12 * self.in_nside**2 * self.chanlist[0] * self.NORIENT
            + totnchan * self._n_kernel_weights() * self.NORIENT * self.NORIENT
            + self.chanlist[-1] * self.out_chan * self.NORIENT
        )

    def set_weights(self, x):
        """Replace the flat weight vector."""
        self.x = x

    def get_weights(self):
        """Return the flat weight vector."""
        return self.x

    def eval(self, im, indices=None, weights=None):
        """Generate maps from the input vectors ``im``.

        Args:
            im: Input, matrix-multiplied by an ``(npar, ...)`` weight, so
                presumably ``(batch, npar)``; confirm.
            indices, weights: Optional per-layer sequences. Element ``k`` of
                each is forwarded to ``healpix_layer`` for layer ``k``.

        Returns:
            Tensor of shape ``(batch, out_chan, npix_out)``.
        """
        bk = self.backend
        x = self.x
        npix_in = 12 * self.in_nside**2

        # Dense layer -> (batch, chanlist[0], NORIENT, npix_in).
        nn = self.npar * npix_in * self.chanlist[0] * self.NORIENT
        ww = bk.bk_reshape(
            x[0:nn],
            [self.npar, npix_in * self.chanlist[0] * self.NORIENT],
        )
        im = bk.bk_matmul(im, ww)
        im = bk.bk_reshape(im, [im.shape[0], self.chanlist[0], self.NORIENT, npix_in])
        # nn now points past the dense weights. It must include NORIENT to
        # match the slice above and get_number_of_weights().

        nk = self._n_kernel_weights()
        for k in range(self.nscale):
            im = bk.bk_relu(im)

            # Upsample: repeat each pixel 4 times along the pixel axis.
            up = bk.bk_repeat(im, 4, axis=-1)
            im = bk.bk_reshape(up, [im.shape[0], im.shape[1], self.NORIENT, im.shape[3] * 4])

            nw = nk * self.NORIENT * self.NORIENT * self.chanlist[k] * self.chanlist[k + 1]
            ww = bk.bk_reshape(
                x[nn:nn + nw],
                [self.chanlist[k], self.NORIENT, nk, self.chanlist[k + 1], self.NORIENT],
            )
            nn += nw

            if indices is None:
                im = self.scat_operator.healpix_layer(im, ww)
            else:
                im = self.scat_operator.healpix_layer(
                    im, ww, indices=indices[k], weights=weights[k]
                )

        # Per-channel/orientation weights, averaged over channels and orientations.
        ww = bk.bk_reshape(
            x[nn:nn + self.chanlist[-1] * self.NORIENT * self.out_chan],
            [1, self.chanlist[-1], self.NORIENT, self.out_chan, 1],
        )
        return bk.bk_reduce_mean(im[:, :, :, None] * ww, [1, 2])