foscat 3.4.0__py3-none-any.whl → 3.6.1__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
foscat/FoCUS.py CHANGED
@@ -38,7 +38,7 @@ class FoCUS:
  mpi_rank=0,
  ):

- self.__version__ = "3.4.0"
+ self.__version__ = "3.6.1"
  # P00 coeff for normalization for scat_cov
  self.TMPFILE_VERSION = TMPFILE_VERSION
  self.P1_dic = None
foscat/alm.py CHANGED
@@ -1,48 +1,137 @@
  import healpy as hp
  import numpy as np
+ import time

  class alm():

  def __init__(self,backend=None,lmax=24,
  nside=None,limit_range=1E10):
+
+ if backend is None:
+ import foscat.scat_cov as sc
+ self.sc=sc.funct()
+ self.backend=self.sc.backend
+ else:
+ self.backend=backend.backend
+
  self._logtab={}
  self.lth={}
+ self.lph={}
+ self.matrix_shift_ph={}
+ self.ratio_mm={}
+ self.P_mm={}
+ self.A={}
+ self.B={}
  if nside is not None:
+ self.maxlog=6*nside+1
  self.lmax=3*nside
- th,ph=hp.pix2ang(nside,np.arange(12*nside*nside))
-
- lth=np.unique(th)
-
- self.lth[nside]=lth
  else:
  self.lmax=lmax
-
- for k in range(1,2*self.lmax+1):
- self._logtab[k]=np.log(k)
+ self.maxlog=2*lmax+1
+
+ for k in range(1,self.maxlog):
+ self._logtab[k]=self.backend.bk_log(self.backend.bk_cast(k))
  self._logtab[0]=0.0
+
+ if nside is not None:
+ self.ring_th(nside)
+ self.ring_ph(nside)
+ self.shift_ph(nside)
+
  self._limit_range=1/limit_range
  self._log_limit_range=np.log(limit_range)

- if backend is None:
- import foscat.scat_cov as sc
- self.sc=sc.funct()
- self.backend=self.sc.backend
- else:
- self.backend=backend.backend

  self.Yp={}
  self.Ym={}

  def ring_th(self,nside):
  if nside not in self.lth:
- th,ph=hp.pix2ang(nside,np.arange(12*nside*nside))
-
- lth=np.unique(th)
-
- self.lth[nside]=lth
+ n=0
+ ith=[]
+ for k in range(nside-1):
+ N=4*(k+1)
+ ith.append(n)
+ n+=N
+
+ for k in range(2*nside+1):
+ N=4*nside
+ ith.append(n)
+ n+=N
+ for k in range(nside-1):
+ N=4*(nside-1-k)
+ ith.append(n)
+ n+=N
+
+ th,ph=hp.pix2ang(nside,ith)
+
+ self.lth[nside]=th
  return self.lth[nside]
+
+ def ring_ph(self,nside):
+ if nside not in self.lph:
+ n=0
+ iph=[]
+ for k in range(nside-1):
+ N=4*(k+1)
+ iph.append(n)
+ n+=N
+
+ for k in range(2*nside+1):
+ N=4*nside
+ iph.append(n)
+ n+=N
+ for k in range(nside-1):
+ N=4*(nside-1-k)
+ iph.append(n)
+ n+=N
+
+ th,ph=hp.pix2ang(nside,iph)
+
+ self.lph[nside]=ph

+ def shift_ph(self,nside):
+
+ if nside not in self.matrix_shift_ph:
+ self.ring_th(nside)
+ self.ring_ph(nside)
+ x=(-1J*np.arange(3*nside)).reshape(1,3*nside)
+ self.matrix_shift_ph[nside]=self.backend.bk_cast(self.backend.bk_exp(x*self.lph[nside].reshape(4*nside-1,1)))

+ self.lmax=3*nside-1
+
+ ratio_mm={}
+
+ for m in range(3*nside):
+ val=np.zeros([self.lmax-m+1])
+ aval=np.zeros([self.lmax-m+1])
+ bval=np.zeros([self.lmax-m+1])
+
+ if m>0:
+ val[0]=self.double_factorial_log(2*m - 1)-0.5*np.sum(np.log(1+np.arange(2*m)))
+ else:
+ val[0]=self.double_factorial_log(2*m - 1)
+ if m<self.lmax:
+ aval[1]=(2*m + 1)
+ val[1] = val[0]-0.5*self.log(2*m+1)
+
+ for l in range(m + 2, self.lmax+1):
+ aval[l-m]=(2*l - 1)/ (l - m)
+ bval[l-m]=(l + m - 1)/ (l - m)
+ val[l-m] = val[l-m-1] + 0.5*self.log(l-m) - 0.5*self.log(l+m)
+
+ self.A[nside,m]=self.backend.constant((aval))
+ self.B[nside,m]=self.backend.constant((bval))
+ self.ratio_mm[nside,m]=self.backend.constant(np.sqrt(4*np.pi)*np.expand_dims(np.exp(val),1))
+ # Calcul de P_{mm}(x)
+ P_mm=np.ones([3*nside,4*nside-1])
+ x=np.cos(self.lth[nside])
+ if m == 0:
+ P_mm[m] = 1.0
+ for m in range(3*nside-1):
+ P_mm[m] = (0.5-m%2)*2 * (1 - x**2)**(m/2)
+ self.P_mm[nside]=self.backend.constant(P_mm)
+
  def init_Ys(self,s,nside):

  if (s,nside) not in self.Yp:
@@ -52,7 +141,7 @@ class alm():
  ell_max = 3*nside-1 # Use the largest ℓ value you expect to need
  wigner = spherical.Wigner(ell_max)

- th,ph=hp.pix2ang(nside,np.arange(12*nside*nside))
+ #th,ph=hp.pix2ang(nside,np.arange(12*nside*nside))

  lth=self.ring_th(nside)

@@ -64,21 +153,22 @@ class alm():

  for m in range(ell_max+1):
  idx=np.array([wigner.Yindex(k, m) for k in range(m,ell_max+1)])
- self.Yp[s,nside][m] = iplus[idx]
- self.Ym[s,nside][m] = imoins[idx]
+ vnorm=1/np.expand_dims(np.sqrt(2*(np.arange(ell_max-m+1)+m)+1),1)
+ self.Yp[s,nside][m] = iplus[idx]*vnorm
+ self.Ym[s,nside][m] = imoins[idx]*vnorm

  del(iplus)
  del(imoins)
  del(wigner)

  def log(self,v):
- #return np.log(v)
+ return np.log(v)
  if isinstance(v,np.ndarray):
- return np.array([self.log(k) for k in v])
- if v<self.lmax*2+1:
+ return np.array([self.backend.bk_log(self.backend.bk_cast(k)) for k in v])
+ if v<self.maxlog:
  return self._logtab[v]
  else:
- self._logtab[v]=np.log(v)
+ self._logtab[v]=self.backend.bk_log(self.backend.bk_cast(v))
  return self._logtab[v]

  # Fonction pour calculer la double factorielle
@@ -87,11 +177,21 @@ class alm():
  return 0.0
  result = 0.0
  for i in range(n, 0, -2):
- result += self.log(i)
+ result += np.log(i)
  return result

- # Calcul des P_{lm}(x) pour tout l inclus dans [m,lmax]
- def compute_legendre_m(self,x,m,lmax):
+ def recurrence_fn(self,states, inputs):
+ """
+ Fonction de récurrence pour tf.scan.
+ states: un tuple (U_{n-1}, U_{n-2}) de forme [m]
+ inputs: un tuple (a_n(x), b_n) où a_n(x) est de forme [m]
+ """
+ U_prev, U_prev2 = states
+ a_n, b_n = inputs # a_n est de forme [m], b_n est un scalaire
+ U_n = a_n * U_prev - b_n * U_prev2
+ return (U_n, U_prev) # Avancer les états
+ # Calcul des P_{lm}(x) pour tout l inclus dans [m,lmax]
+ def compute_legendre_m(self,x,m,lmax,nside):
  result=np.zeros([lmax-m+1,x.shape[0]])
  ratio=np.zeros([lmax-m+1,1])

@@ -109,7 +209,7 @@ class alm():
  result[0] = Pmm

  if m == lmax:
- return result*np.exp(ratio)*np.sqrt(4*np.pi*(2*(np.arange(lmax-m+1)+m)+1)).reshape(lmax+1-m,1)
+ return result*np.exp(ratio)*np.sqrt(4*np.pi)

  # Étape 2 : Calcul de P_{l+1, m}(x)
  result[1] = x * (2*m + 1) * result[0]
@@ -126,7 +226,81 @@ class alm():
  ratio[l-m-1,0]+= self._log_limit_range
  ratio[l-m,0]+= self._log_limit_range

- return result*np.exp(ratio)*(np.sqrt(4*np.pi*(2*(np.arange(lmax-m+1)+m)+1))).reshape(lmax+1-m,1)
+ return result*np.exp(ratio)*np.sqrt(4*np.pi)
+
+ # Calcul des P_{lm}(x) pour tout l inclus dans [m,lmax]
+ def compute_legendre_m_old2(self,x,m,lmax,nside):
+
+ result={}
+
+ # Si l == m, c'est directement P_{mm}
+ result[0] = self.P_mm[nside][m]
+
+ if m == lmax:
+ v=self.backend.bk_reshape(result[0]*self.ratio_mm[nside,m][0],[1,4*nside-1])
+ return self.backend.bk_complex(v,0*v)
+
+ # Étape 2 : Calcul de P_{l+1, m}(x)
+ result[1] = x * self.A[nside,m][1] * result[0]
+
+ # Étape 3 : Récurence pour l > m + 1
+ for l in range(m + 2, lmax+1):
+ result[l-m] = self.A[nside,m][l-m] * x * result[l-m-1] - self.B[nside,m][l-m] * result[l-m-2]
+ """
+ if np.max(abs(result[l-m]))>self._limit_range:
+ result[l-m-1]*= self._limit_range
+ result[l-m]*= self._limit_range
+ ratio[l-m-1]+= self._log_limit_range
+ ratio[l-m]+= self._log_limit_range
+ """
+ result=self.backend.bk_reshape(self.backend.bk_concat([result[k] for k in range(lmax+1-m)],axis=0),[lmax+1-m,4*nside-1])
+
+ return self.backend.bk_complex(result*self.ratio_mm[nside,m],0*result)
+
+
+ def compute_legendre_m_old(self,x,m,lmax,nside):
+
+ import tensorflow as tf
+ result={}
+
+ # Si l == m, c'est directement P_{mm}
+ U_0 = self.P_mm[nside][m]
+
+ if m == lmax:
+ v=self.backend.bk_reshape(U_0*self.ratio_mm[nside,m][0],[1,4*nside-1])
+ return self.backend.bk_complex(v,0*v)
+
+ # Étape 2 : Calcul de P_{l+1, m}(x)
+ U_1 = x * self.A[nside,m][1] * U_0
+ if m == lmax-1:
+ result = tf.concat([self.backend.bk_expand_dims(U_0,0),
+ self.backend.bk_expand_dims(U_1,0)],0)
+ return self.backend.bk_complex(result*self.ratio_mm[nside,m],0*result)
+
+ a_values = self.backend.bk_expand_dims(self.A[nside,m],1)*self.backend.bk_expand_dims(x,0)
+ # Initialiser les états avec (U_1, U_0) pour chaque m
+ initial_states = (U_1, U_0)
+ inputs = (a_values[2:], self.B[nside,m][2:])
+ # Appliquer tf.scan
+ result = tf.scan(self.recurrence_fn, inputs, initializer=initial_states)
+ # Le premier élément de result contient les U[n]
+ result = tf.concat([self.backend.bk_expand_dims(U_0,0),
+ self.backend.bk_expand_dims(U_1,0),
+ result[0]], axis=0)
+ """
+ # Étape 3 : Récurence pour l > m + 1
+ for l in range(m + 2, lmax+1):
+ result[l-m] = self.A[nside,m][l-m] * x * result[l-m-1] - self.B[nside,m][l-m] * result[l-m-2]
+
+ if np.max(abs(result[l-m]))>self._limit_range:
+ result[l-m-1]*= self._limit_range
+ result[l-m]*= self._limit_range
+ ratio[l-m-1]+= self._log_limit_range
+ ratio[l-m]+= self._log_limit_range
+ result=self.backend.bk_reshape(self.backend.bk_concat([result[k] for k in range(lmax+1-m)],axis=0),[lmax+1-m,4*nside-1])
+ """
+
+ return self.backend.bk_complex(result*self.ratio_mm[nside,m],0*result)


  # Calcul des s_P_{lm}(x) pour tout l inclus dans [m,lmax]
@@ -194,45 +368,139 @@ class alm():

  ylm_moins = alpha_moins*ylm[1:] + beta_moins*ylm[:-1]

  return ylm_plus,ylm_moins
+
+ def rfft2fft(self,val,axis=0):
+ r=self.backend.bk_rfft(val)
+ if axis==0:
+ r_inv=self.backend.bk_reverse(self.backend.bk_conjugate(r[1:-1]),axis=axis)
+ else:
+ r_inv=self.backend.bk_reverse(self.backend.bk_conjugate(r[:,1:-1]),axis=axis)
+ return self.backend.bk_concat([r,r_inv],axis=axis)

- def comp_tf(self,im,ph):
- nside=int(np.sqrt(im.shape[0]//12))
+ def irfft2fft(self,val,N,axis=0):
+ if axis==0:
+ return self.backend.bk_irfft(val[0:N//2+1])
+ else:
+ return self.backend.bk_irfft(val[:,0:N//2+1])
+
+ def comp_tf(self,im,nside,realfft=False):
+
+ self.shift_ph(nside)
  n=0
- ii=0
+
  ft_im=[]
  for k in range(nside-1):
  N=4*(k+1)
- l_n=N
- if l_n>3*nside:
- l_n=3*nside
- tmp=self.backend.bk_fft(im[n:n+N])[0:l_n]
- ft_im.append(tmp*np.exp(-1J*np.arange(l_n)*ph[n]))
- ft_im.append(self.backend.bk_zeros((3*nside-l_n),dtype=self.backend.all_cbk_type))
- # if N<3*nside fill the tf with rotational values to mimic alm_tools.F90 of healpix (Minor effect)
- #for m in range(l_n,3*nside,l_n):
- # ft_im.append(tmp[0:np.min([3*nside-m,l_n])])
- n+=N
- ii+=1
- for k in range(2*nside+1):
- N=4*nside
- ft_im.append(self.backend.bk_fft(im[n:n+N])[:3*nside]*np.exp(-1J*np.arange(3*nside)*ph[n]))
+
+ if realfft:
+ tmp=self.rfft2fft(im[n:n+N])
+ else:
+ tmp=self.backend.bk_fft(im[n:n+N])
+
+ l_n=tmp.shape[0]
+
+ if l_n<3*nside+1:
+ repeat_n=3*nside//l_n+1
+ tmp=self.backend.bk_tile(tmp,repeat_n,axis=0)
+
+ ft_im.append(tmp[0:3*nside])
+
  n+=N
- ii+=1
+ if nside>1:
+ result=self.backend.bk_reshape(self.backend.bk_concat(ft_im,axis=0),[nside-1,3*nside])
+
+ N=4*nside*(2*nside+1)
+ v=self.backend.bk_reshape(im[n:n+N],[2*nside+1,4*nside])
+ if realfft:
+ v_fft=self.rfft2fft(v,axis=1)[:,:3*nside]
+ else:
+ v_fft=self.backend.bk_fft(v)[:,:3*nside]
+
+ n+=N
+ if nside>1:
+ result=self.backend.bk_concat([result,v_fft],axis=0)
+ else:
+ result=v_fft
+
+ if nside>1:
+ ft_im=[]
+ for k in range(nside-1):
+ N=4*(nside-1-k)
+
+ if realfft:
+ tmp=self.rfft2fft(im[n:n+N])[0:l_n]
+ else:
+ tmp=self.backend.bk_fft(im[n:n+N])[0:l_n]
+
+ l_n=tmp.shape[0]
+
+ if l_n<3*nside+1:
+ repeat_n=3*nside//l_n+1
+ tmp=self.backend.bk_tile(tmp,repeat_n,axis=0)
+
+ ft_im.append(tmp[0:3*nside])
+ n+=N
+
+ lastresult=self.backend.bk_reshape(self.backend.bk_concat(ft_im,axis=0),[nside-1,3*nside])
+ return self.backend.bk_concat([result,lastresult],axis=0)*self.matrix_shift_ph[nside]
+ else:
+ return result*self.matrix_shift_ph[nside]
+
+
+ def icomp_tf(self,i_im,nside,realfft=False):
+
+ self.shift_ph(nside)
+
+ n=0
+ im=[]
+ ft_im=i_im*self.backend.bk_conjugate(self.matrix_shift_ph[nside])
+
  for k in range(nside-1):
- N=4*(nside-1-k)
- l_n=N
- if l_n>3*nside:
- l_n=3*nside
- tmp=self.backend.bk_fft(im[n:n+N])[0:l_n]
- ft_im.append(tmp*np.exp(-1J*np.arange(l_n)*ph[n]))
- ft_im.append(self.backend.bk_zeros((3*nside-l_n),dtype=self.backend.all_cbk_type))
- # if N<3*nside fill the tf with rotational values to mimic alm_tools.F90 of healpix (Minor effect)
- #for m in range(l_n,3*nside,l_n):
- # ft_im.append(tmp[0:np.min([3*nside-m,l_n])])
+ N=4*(k+1)
+
+ if realfft:
+ tmp=self.irfft2fft(ft_im[k],N)
+ else:
+ tmp=self.backend.bk_ifft(im[k],N)
+
+ im.append(tmp[0:N])
+
  n+=N
- ii+=1
- return self.backend.bk_reshape(self.backend.bk_concat(ft_im,axis=0),[4*nside-1,3*nside])
+
+ if nside>1:
+ result=self.backend.bk_concat(im,axis=0)
+
+ N=4*nside*(2*nside+1)
+ v=ft_im[nside-1:3*nside,0:2*nside+1]
+ if realfft:
+ v_fft=self.backend.bk_reshape(self.irfft2fft(v,N,axis=1),[4*nside*(2*nside+1)])
+ else:
+ v_fft=self.backend.bk_ifft(v)
+
+ n+=N
+ if nside>1:
+ result=self.backend.bk_concat([result,v_fft],axis=0)
+ else:
+ result=v_fft
+
+ if nside>1:
+ im=[]
+ for k in range(nside-1):
+ N=4*(nside-1-k)
+
+ if realfft:
+ tmp=self.irfft2fft(ft_im[k+3*nside],N)
+ else:
+ tmp=self.backend.bk_ifft(im[k+3*nside],N)
+
+ im.append(tmp[0:N])
+
+ n+=N

+ return self.backend.bk_concat([result]+im,axis=0)
+ else:
+ return result
+
  def anafast(self,im,map2=None,nest=False,spin=2):

  """The `anafast` function computes the L1 and L2 norm power spectra.
@@ -253,38 +521,41 @@ class alm():
  ordered as TT, EE, BB, TE, EB.TBanafast function computes L1 and L2 norm powerspctra.

  """
+ i_im=self.backend.bk_cast(im)
+ if map2 is not None:
+ i_map2=self.backend.bk_cast(map2)
+
  doT=True
- if len(im.shape)==1: # nopol
- nside=int(np.sqrt(im.shape[0]//12))
+ if len(i_im.shape)==1: # nopol
+ nside=int(np.sqrt(i_im.shape[0]//12))
  else:
- if im.shape[0]==2:
+ if i_im.shape[0]==2:
  doT=False
-
- nside=int(np.sqrt(im.shape[1]//12))
+ nside=int(np.sqrt(i_im.shape[1]//12))

- th,ph=hp.pix2ang(nside,np.arange(12*nside*nside))
-
+ self.shift_ph(nside)
+
  if doT: # nopol
+ if len(i_im.shape)==2: # pol
+ l_im=i_im[0]
+ if map2 is not None:
+ l_map2=i_map2[0]
+ else:
+ l_im=i_im
+ if map2 is not None:
+ l_map2=i_map2
+
  if nest:
  idx=hp.ring2nest(nside,np.arange(12*nside**2))
- if len(im.shape)==1: # nopol
- ft_im=self.comp_tf(self.backend.bk_complex(self.backend.bk_gather(im,idx),0*im),ph)
- if map2 is not None:
- ft_im2=self.comp_tf(self.backend.bk_complex(self.backend.bk_gather(map2,idx),0*im),ph)
- else:
- ft_im=self.comp_tf(self.backend.bk_complex(self.backend.bk_gather(im[0],idx),0*im[0]),ph)
+ if len(i_im.shape)==1: # nopol
+ ft_im=self.comp_tf(self.backend.bk_gather(l_im,idx),nside,realfft=True)
  if map2 is not None:
- ft_im2=self.comp_tf(self.backend.bk_complex(self.backend.bk_gather(map2[0],idx),0*im[0]),ph)
+ ft_im2=self.comp_tf(self.backend.bk_gather(l_map2,idx),nside,realfft=True)
  else:
- if len(im.shape)==1: # nopol
- ft_im=self.comp_tf(self.backend.bk_complex(im,0*im),ph)
- if map2 is not None:
- ft_im2=self.comp_tf(self.backend.bk_complex(map2,0*im),ph)
- else:
- ft_im=self.comp_tf(self.backend.bk_complex(im[0],0*im[0]),ph)
- if map2 is not None:
- ft_im2=self.comp_tf(self.backend.bk_complex(map2[0],0*im[0]),ph)
-
+ ft_im=self.comp_tf(l_im,nside,realfft=True)
+ if map2 is not None:
+ ft_im2=self.comp_tf(l_map2,nside,realfft=True)
+
  lth=self.ring_th(nside)

  co_th=np.cos(lth)
@@ -293,33 +564,35 @@ class alm():

  cl2=None
  cl2_L1=None
-
- if len(im.shape)==2: # nopol
+ dt2=0
+ dt3=0
+ dt4=0
+ if len(i_im.shape)==2: # nopol

  self.init_Ys(spin,nside)

  if nest:
  idx=hp.ring2nest(nside,np.arange(12*nside**2))
- l_Q=self.backend.bk_gather(im[int(doT)],idx)
- l_U=self.backend.bk_gather(im[1+int(doT)],idx)
- ft_im_Pp=self.comp_tf(self.backend.bk_complex(l_Q,l_U),ph)
- ft_im_Pm=self.comp_tf(self.backend.bk_complex(l_Q,-l_U),ph)
+ l_Q=self.backend.bk_gather(i_im[int(doT)],idx)
+ l_U=self.backend.bk_gather(i_im[1+int(doT)],idx)
+ ft_im_Pp=self.comp_tf(self.backend.bk_complex(l_Q,l_U),nside)
+ ft_im_Pm=self.comp_tf(self.backend.bk_complex(l_Q,-l_U),nside)
  if map2 is not None:
- l_Q=self.backend.bk_gather(map2[int(doT)],idx)
- l_U=self.backend.bk_gather(map2[1+int(doT)],idx)
- ft_im2_Pp=self.comp_tf(self.backend.bk_complex(l_Q,l_U),ph)
- ft_im2_Pm=self.comp_tf(self.backend.bk_complex(l_Q,-l_U),ph)
+ l_Q=self.backend.bk_gather(i_map2[int(doT)],idx)
+ l_U=self.backend.bk_gather(i_map2[1+int(doT)],idx)
+ ft_im2_Pp=self.comp_tf(self.backend.bk_complex(l_Q,l_U),nside)
+ ft_im2_Pm=self.comp_tf(self.backend.bk_complex(l_Q,-l_U),nside)
  else:
- ft_im_Pp=self.comp_tf(self.backend.bk_complex(im[int(doT)],im[1+int(doT)]),ph)
- ft_im_Pm=self.comp_tf(self.backend.bk_complex(im[int(doT)],-im[1+int(doT)]),ph)
+ ft_im_Pp=self.comp_tf(self.backend.bk_complex(i_im[int(doT)],i_im[1+int(doT)]),nside)
+ ft_im_Pm=self.comp_tf(self.backend.bk_complex(i_im[int(doT)],-i_im[1+int(doT)]),nside)
  if map2 is not None:
- ft_im2_Pp=self.comp_tf(self.backend.bk_complex(map2[int(doT)],map2[1+int(doT)]),ph)
- ft_im2_Pm=self.comp_tf(self.backend.bk_complex(map2[int(doT)],-map2[1+int(doT)]),ph)
+ ft_im2_Pp=self.comp_tf(self.backend.bk_complex(i_map2[int(doT)],i_map2[1+int(doT)]),nside)
+ ft_im2_Pm=self.comp_tf(self.backend.bk_complex(i_map2[int(doT)],-i_map2[1+int(doT)]),nside)

  for m in range(lmax+1):

- plm=self.compute_legendre_m(co_th,m,3*nside-1)/(12*nside**2)
-
+ plm=self.compute_legendre_m(co_th,m,3*nside-1,nside)/(12*nside**2)
+
  if doT:
  tmp=self.backend.bk_reduce_sum(plm*ft_im[:,m],1)

@@ -327,8 +600,8 @@ class alm():
  tmp2=self.backend.bk_reduce_sum(plm*ft_im2[:,m],1)
  else:
  tmp2=tmp
-
- if len(im.shape)==2: # pol
+
+ if len(i_im.shape)==2: # pol
  plmp=self.Yp[spin,nside][m]
  plmm=self.Ym[spin,nside][m]

@@ -398,35 +671,35 @@ class alm():

  if cl2 is None:
  cl2=l_cl
- cl2_l1=self.backend.bk_L1(l_cl)
  else:
  cl2+=2*l_cl
- cl2_l1+=2*self.backend.bk_L1(l_cl)
-
- if len(im.shape)==1: # nopol
- cl2=cl2/(2*np.arange(cl2.shape[0])+1)
- cl2_l1=cl2_l1/(2*np.arange(cl2.shape[0])+1)
- else:
- cl2=cl2/np.expand_dims(2*np.arange(cl2.shape[1])+1,0)
- cl2_l1=cl2_l1/np.expand_dims(2*np.arange(cl2.shape[1])+1,0)
+
+ #cl2=cl2*(4*np.pi) #self.backend.bk_sqrt(self.backend.bk_cast(4*np.pi)) #(2*np.arange(cl2.shape[0])+1)))
+
+ cl2_l1=self.backend.bk_L1(cl2)
+
  return cl2,cl2_l1

  def map2alm(self,im,nest=False):
  nside=int(np.sqrt(im.shape[0]//12))
- th,ph=hp.pix2ang(nside,np.arange(12*nside*nside))
+
+ ph=self.shift_ph(nside)
+
  if nest:
  idx=hp.ring2nest(nside,np.arange(12*nside**2))
- ft_im=self.comp_tf(self.backend.bk_complex(self.backend.bk_gather(im,idx),0*im),ph)
+ ft_im=self.comp_tf(self.backend.bk_cast(self.backend.bk_gather(im,idx)),nside,realfft=True)
  else:
- ft_im=self.comp_tf(self.backend.bk_complex(im,0*im),ph)
+ ft_im=self.comp_tf(self.backend.bk_cast(im),nside,realfft=True)

- co_th=np.cos(self.ring_th(nside))
+ lth=self.ring_th(nside)
+
+ co_th=np.cos(lth)

  lmax=3*nside-1

  alm=None
  for m in range(lmax+1):
- plm=self.compute_legendre_m(co_th,m,3*nside-1)/(12*nside**2)
+ plm=self.compute_legendre_m(co_th,m,3*nside-1,nside)/(12*nside**2)

  tmp=self.backend.bk_reduce_sum(plm*ft_im[:,m],1)
  if m==0:
@@ -436,29 +709,71 @@ class alm():

  return alm

+
+ def alm2map(self,nside,alm):
+
+ lth=self.ring_th(nside)
+
+ co_th=np.cos(lth)
+
+ ft_im=[]
+
+ n=0
+
+ lmax=3*nside-1
+
+ for m in range(lmax+1):
+ plm=self.compute_legendre_m(co_th,m,3*nside-1,nside)/(12*nside**2)
+
+ print(alm[n:n+lmax-m+1].shape,plm.shape)
+ ft_im.append(self.backend.bk_reduce_sum(self.backend.bk_reshape(alm[n:n+lmax-m+1],[lmax-m+1,1])*plm,0))
+
+ n=n+lmax-m+1
+
+ return self.backend.bk_reshape(self.backend.bk_concat(ft_im,0),[lmax+1,4*nside-1])
+
+
+ if nest:
+ idx=hp.ring2nest(nside,np.arange(12*nside**2))
+ ft_im=self.comp_tf(self.backend.bk_cast(self.backend.bk_gather(im,idx)),nside,realfft=True)
+ else:
+ ft_im=self.comp_tf(self.backend.bk_cast(im),nside,realfft=True)
+
+
+ lmax=3*nside-1
+
+ alm=None
+ for m in range(lmax+1):
+ plm=self.compute_legendre_m(co_th,m,3*nside-1,nside)/(12*nside**2)
+
+ tmp=self.backend.bk_reduce_sum(plm*ft_im[:,m],1)
+ if m==0:
+ alm=tmp
+ else:
+ alm=self.backend.bk_concat([alm,tmp],axis=0)
+
+ return o_map
+
  def map2alm_spin(self,im_Q,im_U,spin=2,nest=False):

  if spin==0:
  return self.map2alm(im_Q,nest=nest),self.map2alm(im_U,nest=nest)

-
  nside=int(np.sqrt(im_Q.shape[0]//12))
- th,ph=hp.pix2ang(nside,np.arange(12*nside*nside))
+
+ lth=self.ring_th(nside)

- self.init_Ys(spin,nside)
+ co_th=np.cos(lth)

  if nest:
  idx=hp.ring2nest(nside,np.arange(12*nside**2))
  l_Q=self.backend.bk_gather(im_Q,idx)
  l_U=self.backend.bk_gather(im_U,idx)
- ft_im_1=self.comp_tf(self.backend.bk_complex(l_Q,l_U),ph)
- ft_im_2=self.comp_tf(self.backend.bk_complex(l_Q,-l_U),ph)
+ ft_im_1=self.comp_tf(self.backend.bk_complex(l_Q,l_U),nside)
+ ft_im_2=self.comp_tf(self.backend.bk_complex(l_Q,-l_U),nside)
  else:
- ft_im_1=self.comp_tf(self.backend.bk_complex(im_Q,im_U),ph)
- ft_im_2=self.comp_tf(self.backend.bk_complex(im_Q,-im_U),ph)
-
- #co_th=np.cos(self.ring_th[nside])
- #si_th=np.sin(self.ring_th[nside])
+ ft_im_1=self.comp_tf(self.backend.bk_complex(im_Q,im_U),nside)
+ ft_im_2=self.comp_tf(self.backend.bk_complex(im_Q,-im_U),nside)

  lmax=3*nside-1

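Note on the new ring bookkeeping: ring_th/ring_ph replace the old per-pixel hp.pix2ang scan with one pixel index per iso-latitude ring, using the HEALPix RING layout (4*(k+1) pixels in the k-th north-cap ring, 4*nside pixels in each of the 2*nside+1 equatorial rings, and a mirrored south cap). The following is a minimal numpy/healpy sketch of that bookkeeping, not part of the package; the helper name ring_start_indices is illustrative only.

import healpy as hp
import numpy as np

def ring_start_indices(nside):
    # First RING-ordered pixel index of each of the 4*nside-1 iso-latitude rings.
    counts = ([4 * (k + 1) for k in range(nside - 1)]          # north polar cap
              + [4 * nside] * (2 * nside + 1)                  # equatorial belt
              + [4 * (nside - 1 - k) for k in range(nside - 1)])  # south polar cap
    starts = np.concatenate([[0], np.cumsum(counts)[:-1]])
    return starts, counts

nside = 8
starts, counts = ring_start_indices(nside)
th, ph = hp.pix2ang(nside, starts)       # one (theta, phi) per ring, as cached by lth/lph
assert len(starts) == 4 * nside - 1
assert sum(counts) == 12 * nside**2
# The ring colatitudes agree with the brute-force scan the old code used.
full_th = hp.pix2ang(nside, np.arange(12 * nside**2))[0]
assert np.allclose(np.sort(th), np.unique(full_th))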
foscat/alm_tools.py CHANGED
@@ -9,184 +9,3 @@ class alm_tools():
  def __init__(self):
  pass

- @staticmethod
- def gen_recfac(l_max, m):
- """
- Generate recursion factors used to compute the Ylm of degree m for all l in m <= l <= l_max.
-
- Parameters:
- l_max (int): Maximum degree l.
- m (int): Degree m.
-
- Returns:
- np.ndarray: Recursion factors as a 2D array of shape (2, l_max + 1).
- """
- recfac = np.zeros((2, l_max + 1), dtype=np.float64)
- fm2 = float(m)**2
-
- for l in range(m, l_max + 1):
- fl2 = float(l + 1)**2
- recfac[0, l] = np.sqrt((4.0 * fl2 - 1.0) / (fl2 - fm2))
-
- recfac[1, m:l_max + 1] = 1.0 / recfac[0, m:l_max + 1]
-
- return recfac
-
- @staticmethod
- def gen_recfac_spin(l_max, m, spin):
- """
- Generate recursion factors for spin-weighted spherical harmonics.
-
- Parameters:
- l_max (int): Maximum degree l.
- m (int): Degree m.
- spin (int): Spin weight.
-
- Returns:
- np.ndarray: Recursion factors as a 2D array of shape (2, l_max + 1).
- """
- recfac_spin = np.zeros((2, l_max + 1), dtype=np.float64)
- fm2 = float(m)**2
- s2 = float(spin)**2
-
- for l in range(m, l_max + 1):
- fl2 = float(l + 1)**2
- recfac_spin[0, l] = np.sqrt((4.0 * fl2 - 1.0) / (fl2 - fm2))
-
- recfac_spin[1, m:l_max + 1] = (1.0 - s2 / (float(m) + 1.0)**2) / recfac_spin[0, m:l_max + 1]
-
- return recfac_spin
-
- @staticmethod
- def gen_lamfac(l_max):
- """
- Generate lambda factors for spherical harmonics.
-
- Parameters:
- l_max (int): Maximum degree l.
-
- Returns:
- np.ndarray: Lambda factors as a 1D array of size l_max + 1.
- """
- lamfac = np.zeros(l_max + 1, dtype=np.float64)
-
- for l in range(1, l_max + 1):
- lamfac[l] = np.sqrt(2.0 * l + 1.0)
-
- return lamfac
-
- @staticmethod
- def gen_lamfac_der(l_max):
- """
- Generate the derivatives of lambda factors.
-
- Parameters:
- l_max (int): Maximum degree l.
-
- Returns:
- np.ndarray: Lambda factor derivatives as a 1D array of size l_max + 1.
- """
- lamfac_der = np.zeros(l_max + 1, dtype=np.float64)
-
- for l in range(1, l_max + 1):
- lamfac_der[l] = (2.0 * l + 1.0) / np.sqrt(2.0 * l + 1.0)
-
- return lamfac_der
-
- @staticmethod
- def gen_mfac(m_max):
- """
- Generate m factors for spherical harmonics.
-
- Parameters:
- m_max (int): Maximum degree m.
-
- Returns:
- np.ndarray: M factors as a 1D array of size m_max + 1.
- """
- mfac = np.zeros(m_max + 1, dtype=np.float64)
-
- for m in range(1, m_max + 1):
- mfac[m] = np.sqrt(2.0 * m)
-
- return mfac
-
- @staticmethod
- def gen_mfac_spin(m_max, spin):
- """
- Generate m factors for spin-weighted spherical harmonics.
-
- Parameters:
- m_max (int): Maximum degree m.
- spin (int): Spin weight.
-
- Returns:
- np.ndarray: Spin-weighted m factors as a 1D array of size m_max + 1.
- """
- mfac_spin = np.zeros(m_max + 1, dtype=np.float64)
-
- for m in range(1, m_max + 1):
- mfac_spin[m] = np.sqrt(2.0 * m) * (1.0 - spin**2 / (m + 1)**2)
-
- return mfac_spin
-
- @staticmethod
- def compute_lam_mm(l_max, m):
- """
- Compute lambda values for specific m.
-
- Parameters:
- l_max (int): Maximum degree l.
- m (int): Degree m.
-
- Returns:
- np.ndarray: Lambda values as a 1D array of size l_max + 1.
- """
- lam_mm = np.zeros(l_max + 1, dtype=np.float64)
-
- for l in range(m, l_max + 1):
- lam_mm[l] = (2.0 * l + 1.0) * (1.0 - (m / (l + 1.0))**2)
-
- return lam_mm
-
- @staticmethod
- def do_lam_lm(l_max, m):
- """
- Perform computations for lambda values for all l, m.
-
- Parameters:
- l_max (int): Maximum degree l.
- m (int): Degree m.
-
- Returns:
- np.ndarray: Computed lambda values as a 2D array of size (l_max + 1, l_max + 1).
- """
- lam_lm = np.zeros((l_max + 1, l_max + 1), dtype=np.float64)
-
- for l in range(m, l_max + 1):
- for mp in range(m, l + 1):
- lam_lm[l, mp] = (2.0 * l + 1.0) * (1.0 - (mp / (l + 1.0))**2)
-
- return lam_lm
-
- @staticmethod
- def do_lam_lm_spin(l_max, m, spin):
- """
- Perform computations for spin-weighted lambda values for all l, m.
-
- Parameters:
- l_max (int): Maximum degree l.
- m (int): Degree m.
- spin (int): Spin weight.
-
- Returns:
- np.ndarray: Computed spin-weighted lambda values as a 2D array of size (l_max + 1, l_max + 1).
- """
- lam_lm_spin = np.zeros((l_max + 1, l_max + 1), dtype=np.float64)
-
- for l in range(m, l_max + 1):
- for mp in range(m, l + 1):
- lam_lm_spin[l, mp] = (2.0 * l + 1.0) * (1.0 - spin**2 / (mp + 1.0)**2)
-
- return lam_lm_spin
-
foscat/backend.py CHANGED
@@ -879,7 +879,10 @@ class foscat_backend:
  def bk_repeat(self, data, nn, axis=0):
  return self.backend.repeat(data, nn, axis=axis)

- def bk_tile(self, data, nn, axis=0):
+ def bk_tile(self, data, nn,axis=0):
+ if self.BACKEND == self.TENSORFLOW:
+ return self.backend.tile(data, [nn])
+
  return self.backend.tile(data, nn)

  def bk_roll(self, data, nn, axis=0):
@@ -952,6 +955,14 @@ class foscat_backend:
  if self.BACKEND == self.NUMPY:
  return data[idx]

+ def bk_reverse(self, data,axis=0):
+ if self.BACKEND == self.TENSORFLOW:
+ return self.backend.reverse(data,axis=[axis])
+ if self.BACKEND == self.TORCH:
+ return self.backend.reverse(data,axis=axis)
+ if self.BACKEND == self.NUMPY:
+ return np.reverse(data,axis=axis)
+
  def bk_fft(self, data):
  if self.BACKEND == self.TENSORFLOW:
  return self.backend.signal.fft(data)
@@ -960,6 +971,23 @@ class foscat_backend:
  if self.BACKEND == self.NUMPY:
  return self.backend.fft.fft(data)

+ def bk_rfft(self, data):
+ if self.BACKEND == self.TENSORFLOW:
+ return self.backend.signal.rfft(data)
+ if self.BACKEND == self.TORCH:
+ return self.backend.rfft(data)
+ if self.BACKEND == self.NUMPY:
+ return self.backend.fft.rfft(data)
+
+
+ def bk_irfft(self, data):
+ if self.BACKEND == self.TENSORFLOW:
+ return self.backend.signal.irfft(data)
+ if self.BACKEND == self.TORCH:
+ return self.backend.irfft(data)
+ if self.BACKEND == self.NUMPY:
+ return self.backend.fft.irfft(data)
+
  def bk_conjugate(self, data):

  if self.BACKEND == self.TENSORFLOW:
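Note on the new FFT wrappers: bk_rfft, bk_irfft and bk_reverse back the new alm.rfft2fft helper, which rebuilds a full complex spectrum from the one-sided real FFT by appending the reversed conjugate of the interior bins. A minimal numpy sketch of that identity, assuming an even transform length as on HEALPix rings (4, 8, ..., 4*nside pixels per ring):

import numpy as np

x = np.random.randn(16)                              # even length, like a HEALPix ring
r = np.fft.rfft(x)                                   # one-sided spectrum, bins 0 .. N/2
# Hermitian symmetry: X[N-k] = conj(X[k]); the interior bins r[1:-1], conjugated and
# reversed, give the missing upper half of the full spectrum.
full = np.concatenate([r, np.conjugate(r[1:-1])[::-1]])
assert np.allclose(full, np.fft.fft(x))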
foscat/scat_cov.py CHANGED
@@ -2488,15 +2488,16 @@ class funct(FOC.FoCUS):
  )

  def eval(
- self,
- image1,
- image2=None,
- mask=None,
- norm=None,
- Auto=True,
- calc_var=False,
- cmat=None,
- cmat2=None,
+ self,
+ image1,
+ image2=None,
+ mask=None,
+ norm=None,
+ Auto=True,
+ calc_var=False,
+ cmat=None,
+ cmat2=None,
+ out_nside=None
  ):
  """
  Calculates the scattering correlations for a batch of images. Mean are done over pixels.
@@ -2708,6 +2709,8 @@ class funct(FOC.FoCUS):

  if return_data:
  s0 = I1
+ if out_nside is not None:
+ s0 = self.backend.bk_reduce_mean(self.backend.bk_reshape(s0,[s0.shape[0],12*out_nside**2,(nside//out_nside)**2]),2)
  else:
  if not cross:
  s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
@@ -2778,6 +2781,12 @@ class funct(FOC.FoCUS):
  if return_data:
  if S2 is None:
  S2 = {}
+ if out_nside is not None and out_nside<nside_j3:
+ s2 = self.backend.bk_reduce_mean(
+ self.backend.bk_reshape(s2,[s2.shape[0],
+ 12*out_nside**2,
+ (nside_j3//out_nside)**2,
+ s2.shape[2]]),2)
  S2[j3] = s2
  else:
  if norm == "auto": # Normalize S2
@@ -2818,6 +2827,12 @@ class funct(FOC.FoCUS):
  if return_data:
  if S1 is None:
  S1 = {}
+ if out_nside is not None and out_nside<nside_j3:
+ s1 = self.backend.bk_reduce_mean(
+ self.backend.bk_reshape(s1,[s1.shape[0],
+ 12*out_nside**2,
+ (nside_j3//out_nside)**2,
+ s1.shape[2]]),2)
  S1[j3] = s1
  else:
  ### Normalize S1
@@ -2904,6 +2919,12 @@ class funct(FOC.FoCUS):
  if return_data:
  if S2 is None:
  S2 = {}
+ if out_nside is not None and out_nside<nside_j3:
+ s2 = self.backend.bk_reduce_mean(
+ self.backend.bk_reshape(s2,[s2.shape[0],
+ 12*out_nside**2,
+ (nside_j3//out_nside)**2,
+ s2.shape[2]]),2)
  S2[j3] = s2
  else:
  ### Normalize S2_cross
@@ -2949,6 +2970,12 @@ class funct(FOC.FoCUS):
  if return_data:
  if S1 is None:
  S1 = {}
+ if out_nside is not None and out_nside<nside_j3:
+ s1 = self.backend.bk_reduce_mean(
+ self.backend.bk_reshape(s1,[s1.shape[0],
+ 12*out_nside**2,
+ (nside_j3//out_nside)**2,
+ s1.shape[2]]),2)
  S1[j3] = s1
  else:
  ### Normalize S1
@@ -2979,6 +3006,7 @@ class funct(FOC.FoCUS):
  M2convPsi_dic = {}

  ###### S3
+ nside_j2=nside_j3
  for j2 in range(0, j3 + 1): # j2 <= j3
  if return_data:
  if S4[j3] is None:
@@ -3013,6 +3041,13 @@ class funct(FOC.FoCUS):
  if return_data:
  if S3[j3] is None:
  S3[j3] = {}
+ if out_nside is not None and out_nside<nside_j2:
+ s3 = self.backend.bk_reduce_mean(
+ self.backend.bk_reshape(s3,[s3.shape[0],
+ 12*out_nside**2,
+ (nside_j2//out_nside)**2,
+ s3.shape[2],
+ s3.shape[3]]),2)
  S3[j3][j2] = s3
  else:
  ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3095,6 +3130,19 @@ class funct(FOC.FoCUS):
  if S3[j3] is None:
  S3[j3] = {}
  S3P[j3] = {}
+ if out_nside is not None and out_nside<nside_j2:
+ s3 = self.backend.bk_reduce_mean(
+ self.backend.bk_reshape(s3,[s3.shape[0],
+ 12*out_nside**2,
+ (nside_j2//out_nside)**2,
+ s3.shape[2],
+ s3.shape[3]]),2)
+ s3p = self.backend.bk_reduce_mean(
+ self.backend.bk_reshape(s3p,[s3.shape[0],
+ 12*out_nside**2,
+ (nside_j2//out_nside)**2,
+ s3.shape[2],
+ s3.shape[3]]),2)
  S3[j3][j2] = s3
  S3P[j3][j2] = s3p
  else:
@@ -3154,6 +3202,7 @@ class funct(FOC.FoCUS):
  ) # Add a dimension for NS3

  ##### S4
+ nside_j1=nside_j2
  for j1 in range(0, j2 + 1): # j1 <= j2
  ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
  if not cross:
@@ -3179,6 +3228,14 @@ class funct(FOC.FoCUS):
  if return_data:
  if S4[j3][j2] is None:
  S4[j3][j2] = {}
+ if out_nside is not None and out_nside<nside_j1:
+ s4 = self.backend.bk_reduce_mean(
+ self.backend.bk_reshape(s4,[s4.shape[0],
+ 12*out_nside**2,
+ (nside_j1//out_nside)**2,
+ s4.shape[2],
+ s4.shape[3],
+ s4.shape[4]]),2)
  S4[j3][j2][j1] = s4
  else:
  ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3248,6 +3305,14 @@ class funct(FOC.FoCUS):
  if return_data:
  if S4[j3][j2] is None:
  S4[j3][j2] = {}
+ if out_nside is not None and out_nside<nside_j1:
+ s4 = self.backend.bk_reduce_mean(
+ self.backend.bk_reshape(s4,[s4.shape[0],
+ 12*out_nside**2,
+ (nside_j1//out_nside)**2,
+ s4.shape[2],
+ s4.shape[3],
+ s4.shape[4]]),2)
  S4[j3][j2][j1] = s4
  else:
  ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3292,7 +3357,9 @@ class funct(FOC.FoCUS):
  ],
  axis=2,
  ) # Add a dimension for NS4
-
+ nside_j1=nside_j1 // 2
+ nside_j2=nside_j2 // 2
+
  ###### Reshape for next iteration on j3
  ### Image I1,
  # downscale the I1 [Nbatch, Npix_j3]
@@ -3659,9 +3726,14 @@ class funct(FOC.FoCUS):
  if x.S3P is not None:
  nval += self.backend.bk_size(x.S3P)
  result /= self.backend.bk_cast(nval)
+ return result
  else:
- return self.backend.bk_reduce_sum(x)
- return result
+ if sigma is None:
+ tmp=x-y
+ else:
+ tmp=(x-y)/sigma
+ # do abs in case of complex values
+ return self.backend.bk_abs(self.backend.bk_reduce_mean(self.backend.bk_square(tmp)))

  def reduce_sum(self, x):

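Note on the new out_nside argument: when return_data is set, eval() now averages the per-pixel statistics (s0, s1, s2, s3, s4) down to a coarser HEALPix grid by reshaping the pixel axis into [12*out_nside**2, (nside_j/out_nside)**2] blocks and taking the mean over the block axis. A minimal numpy sketch of that reduction follows; it is only an exact degrade when the pixel axis is in NESTED order (so that the children of each coarse pixel are contiguous), and the array names are illustrative only.

import numpy as np

nside, out_nside, nbatch = 16, 4, 2
coeff = np.random.randn(nbatch, 12 * nside**2)     # e.g. a per-pixel s0 term
# Each coarse pixel averages its (nside // out_nside)**2 NESTED children.
coarse = coeff.reshape(nbatch, 12 * out_nside**2, (nside // out_nside)**2).mean(axis=2)
assert coarse.shape == (nbatch, 12 * out_nside**2)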
foscat/scat_cov_map.py CHANGED
@@ -24,10 +24,10 @@ class funct(scat.funct):
  super().__init__(return_data=True, *args, **kwargs)

  def eval(
- self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False
+ self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False,out_nside=None
  ):
  r = super().eval(
- image1, image2=image2, mask=mask, norm=norm, Auto=Auto, calc_var=calc_var
+ image1, image2=image2, mask=mask, norm=norm, Auto=Auto, calc_var=calc_var,out_nside=out_nside
  )
  return scat_cov_map(
  r.S2, r.S0, r.S3, r.S4, S1=r.S1, S3P=r.S3P, backend=r.backend
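Note on usage: scat_cov_map.funct.eval simply forwards the new out_nside keyword to scat_cov.funct.eval. A hypothetical usage sketch; the constructor arguments are not shown in this diff and may differ, only the out_nside keyword itself comes from the change.

import numpy as np
import foscat.scat_cov_map as sc_map

op = sc_map.funct()                  # assumes the no-argument construction used elsewhere in foscat
im = np.random.randn(12 * 16**2)     # an nside=16 HEALPix map
stats = op.eval(im, out_nside=4)     # per-pixel coefficients reduced to nside=4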
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.1
  Name: foscat
- Version: 3.4.0
+ Version: 3.6.1
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -1,27 +1,27 @@
  foscat/CNN.py,sha256=j0F2a4Xf3LijhyD_WVZ6Eg_IjGuXw3ddH6Iudj1xVaw,4874
  foscat/CircSpline.py,sha256=DjP1gy88cnXu2O21ww_lNnsHAHXc3OAWk_8ey84yicg,4053
- foscat/FoCUS.py,sha256=gq2KcT3GN16Ypcy-9OOHo_OVfjDxfWKV0bU2dvzScWI,101774
+ foscat/FoCUS.py,sha256=YoZqPiDLQanUXGgN4RyN6S4_F2rHn273P6QQjq0hlBU,101774
  foscat/GCNN.py,sha256=5RV-FKuvqbD-k99TwiM4CttM2LMZE21WD0IK0j5Mkko,7599
  foscat/Softmax.py,sha256=aBLQauoG0q2SJYPotV6U-cxAhsJcspWHNRWdnA_nAiQ,2854
  foscat/Spline1D.py,sha256=a5Jb8I9tb8y20iM8W-z6iZsIqDFByRp6eZdChpmuI5k,3010
  foscat/Synthesis.py,sha256=3_Lq5-gUM-WmO2h15kajMES8XjRo2BGseoxvTLW_xEc,13626
  foscat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- foscat/alm.py,sha256=Jte3HBEoLahH8sIvEZ-hQKfY1cpRdYKITWcqAe-geBA,19451
- foscat/alm_tools.py,sha256=zI6r7VWt4oCEFoHyK3uEaIsLqazYfDg6MblsQwzozAs,5402
- foscat/backend.py,sha256=Hh4FbjVgzlbyG-HfagULk6StQo5eg-Z_vkHQ9XIVRGE,39182
+ foscat/alm.py,sha256=IIvYOz0kAtufhlBzS_QOx891-Y30ARmofyqomOFmJDY,29335
+ foscat/alm_tools.py,sha256=BMObyIMF3_v04JdvMWBMMaF2F8s2emvDKmirKDnHWDA,387
+ foscat/backend.py,sha256=bN-b0CWcJXVsLCyqMhJACom0JlSStgrMWyD0uB6HqoU,40218
  foscat/backend_tens.py,sha256=9Dp136m9frkclkwifJQLLbIpl3ETI3_txdPUZcKfuMw,1618
  foscat/loss_backend_tens.py,sha256=dCOVN6faDtIpN3VO78HTmYP2i5fnFAf-Ddy5qVBlGrM,1783
  foscat/loss_backend_torch.py,sha256=k3z18Dj3SaLKK6ZIKcm7GO4U_YKYVP6LtHG1aIbxkYk,1627
  foscat/scat.py,sha256=qGYiBIysPt65MdmF07WWA4piVlTfA9-lFDTaicnqC2w,72822
  foscat/scat1D.py,sha256=W5Uu6wdQ4ZsFKXpof0f1OBl-1wjJmW7ruvddRWxe7uM,53726
  foscat/scat2D.py,sha256=boKj0ASqMMSy7uQLK6hPniG87m3hZGJBYBiq5v8F9IQ,532
- foscat/scat_cov.py,sha256=6dItWbxQG7t9_AALDMvaidJ3s0d5ospMR5OS5AFbedY,145576
+ foscat/scat_cov.py,sha256=gQvGC2s93oDurBDW0A6mGyzzMETjYAuBsBsDOI-ZqUY,150642
  foscat/scat_cov1D.py,sha256=XOxsZZ5TYq8f34i2tUgIfzyaqaTDlICB3HzD2l_puro,531
  foscat/scat_cov2D.py,sha256=3gn6xjKvfKsyHJoPfYIu8q9LLVAbU3tsiS2l1LAJ0XM,531
- foscat/scat_cov_map.py,sha256=0wTRo4Nc7rYfI09RI2mh2bYixoukt5lrvAXR6wa9kjA,2744
+ foscat/scat_cov_map.py,sha256=9Yymbr1MxUNY5nJd9kIEEHt1x2IoOjc0EW4kkJVtmQ4,2783
  foscat/scat_cov_map2D.py,sha256=FqF45FBcoiQbvuVsrLWUIPRUc95GsKsrnH6fKzB3GlE,2841
- foscat-3.4.0.dist-info/LICENCE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
- foscat-3.4.0.dist-info/METADATA,sha256=obU7dCa9Mbe6SZ4ptdpHTB3bGtbPeW0GlX0aKIBH9Yg,7216
- foscat-3.4.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- foscat-3.4.0.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
- foscat-3.4.0.dist-info/RECORD,,
+ foscat-3.6.1.dist-info/LICENCE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
+ foscat-3.6.1.dist-info/METADATA,sha256=6t347EDYxBbM5qhlC5T6pwbJDBwvAnJ5VJnNppblC6A,7216
+ foscat-3.6.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ foscat-3.6.1.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
+ foscat-3.6.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.8.0)
+ Generator: bdist_wheel (0.42.0)
  Root-Is-Purelib: true
  Tag: py3-none-any