foscat 3.0.9__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -1,766 +1,1468 @@
1
- import numpy as np
1
+ import os
2
+ import sys
3
+
2
4
  import healpy as hp
3
- import os, sys
4
- import foscat.backend as bk
5
+ import numpy as np
5
6
  from scipy.interpolate import griddata
6
7
 
7
- TMPFILE_VERSION='V2_6'
8
+ import foscat.backend as bk
8
9
 
9
- class FoCUS:
10
- def __init__(self,
11
- NORIENT=4,
12
- LAMBDA=1.2,
13
- KERNELSZ=3,
14
- slope=1.0,
15
- all_type='float64',
16
- nstep_max=16,
17
- padding='SAME',
18
- gpupos=0,
19
- mask_thres=None,
20
- mask_norm=False,
21
- OSTEP=0,
22
- isMPI=False,
23
- TEMPLATE_PATH='data',
24
- BACKEND='tensorflow',
25
- use_2D=False,
26
- return_data=False,
27
- JmaxDelta=0,
28
- DODIV=False,
29
- InitWave=None,
30
- mpi_size=1,
31
- mpi_rank=0):
10
+ TMPFILE_VERSION = "V4_0"
32
11
 
12
+
13
+ class FoCUS:
14
+ def __init__(
15
+ self,
16
+ NORIENT=4,
17
+ LAMBDA=1.2,
18
+ KERNELSZ=3,
19
+ slope=1.0,
20
+ all_type="float64",
21
+ nstep_max=16,
22
+ padding="SAME",
23
+ gpupos=0,
24
+ mask_thres=None,
25
+ mask_norm=False,
26
+ OSTEP=0,
27
+ isMPI=False,
28
+ TEMPLATE_PATH="data",
29
+ BACKEND="tensorflow",
30
+ use_2D=False,
31
+ use_1D=False,
32
+ return_data=False,
33
+ JmaxDelta=0,
34
+ DODIV=False,
35
+ InitWave=None,
36
+ silent=False,
37
+ mpi_size=1,
38
+ mpi_rank=0,
39
+ ):
40
+
41
+ self.__version__ = "3.6.0"
33
42
  # P00 coeff for normalization for scat_cov
34
- self.TMPFILE_VERSION=TMPFILE_VERSION
43
+ self.TMPFILE_VERSION = TMPFILE_VERSION
35
44
  self.P1_dic = None
36
45
  self.P2_dic = None
37
- self.isMPI=isMPI
46
+ self.isMPI = isMPI
38
47
  self.mask_thres = mask_thres
39
48
  self.mask_norm = mask_norm
40
- self.InitWave=InitWave
41
-
42
- self.mpi_size=mpi_size
43
- self.mpi_rank=mpi_rank
44
- self.return_data=return_data
45
-
46
- print('================================================')
47
- print(' START FOSCAT CONFIGURATION')
48
- print('================================================')
49
- sys.stdout.flush()
50
-
51
- self.TEMPLATE_PATH=TEMPLATE_PATH
52
- if os.path.exists(self.TEMPLATE_PATH)==False:
53
- print('The directory %s to store temporary information for FoCUS does not exist: Try to create it'%(self.TEMPLATE_PATH))
49
+ self.InitWave = InitWave
50
+
51
+ self.mpi_size = mpi_size
52
+ self.mpi_rank = mpi_rank
53
+ self.return_data = return_data
54
+ self.silent = silent
55
+
56
+ if not self.silent:
57
+ print("================================================")
58
+ print(" START FOSCAT CONFIGURATION")
59
+ print("================================================")
60
+ sys.stdout.flush()
61
+
62
+ self.TEMPLATE_PATH = TEMPLATE_PATH
63
+ if not os.path.exists(self.TEMPLATE_PATH):
64
+ if not self.silent:
65
+ print(
66
+ "The directory %s to store temporary information for FoCUS does not exist: Try to create it"
67
+ % (self.TEMPLATE_PATH)
68
+ )
54
69
  try:
55
- os.system('mkdir -p %s'%(self.TEMPLATE_PATH))
56
- print('The directory %s is created')
70
+ os.system("mkdir -p %s" % (self.TEMPLATE_PATH))
71
+ if not self.silent:
72
+ print("The directory %s is created")
57
73
  except:
58
- print('Impossible to create the directory %s'%(self.TEMPLATE_PATH))
59
- exit(0)
60
-
61
- self.number_of_loss=0
62
-
63
- self.history=np.zeros([10])
64
- self.nlog=0
65
- self.padding=padding
66
-
67
- if OSTEP!=0:
68
- print('OPTION option is deprecated after version 2.0.6. Please use Jmax option')
69
- JmaxDelta=OSTEP
74
+ if not self.silent:
75
+ print(
76
+ "Impossible to create the directory %s" % (self.TEMPLATE_PATH)
77
+ )
78
+ return None
79
+
80
+ self.number_of_loss = 0
81
+
82
+ self.history = np.zeros([10])
83
+ self.nlog = 0
84
+ self.padding = padding
85
+
86
+ if OSTEP != 0:
87
+ if not self.silent:
88
+ print(
89
+ "OPTION option is deprecated after version 2.0.6. Please use Jmax option"
90
+ )
91
+ JmaxDelta = OSTEP
70
92
  else:
71
- OSTEP=JmaxDelta
72
-
73
- if JmaxDelta<-1:
74
- print('Warning : Jmax can not be smaller than -1')
75
- exit(0)
76
-
77
- self.OSTEP=JmaxDelta
78
- self.use_2D=use_2D
79
-
93
+ OSTEP = JmaxDelta
94
+
95
+ if JmaxDelta < -1:
96
+ if not self.silent:
97
+ print("Warning : Jmax can not be smaller than -1")
98
+ return None
99
+
100
+ self.OSTEP = JmaxDelta
101
+ self.use_2D = use_2D
102
+ self.use_1D = use_1D
103
+
80
104
  if isMPI:
81
105
  from mpi4py import MPI
82
106
 
83
- self.comm= MPI.COMM_WORLD
84
- if all_type=='float32':
85
- self.MPI_ALL_TYPE=MPI.FLOAT
107
+ self.comm = MPI.COMM_WORLD
108
+ if all_type == "float32":
109
+ self.MPI_ALL_TYPE = MPI.FLOAT
86
110
  else:
87
- self.MPI_ALL_TYPE=MPI.DOUBLE
111
+ self.MPI_ALL_TYPE = MPI.DOUBLE
88
112
  else:
89
- self.MPI_ALL_TYPE=None
90
-
91
- self.all_type=all_type
92
- self.BACKEND=BACKEND
93
- self.backend=bk.foscat_backend(BACKEND,
94
- all_type=all_type,
95
- mpi_rank=mpi_rank,
96
- gpupos=gpupos)
97
-
98
- self.all_bk_type=self.backend.all_bk_type
99
- self.all_cbk_type=self.backend.all_cbk_type
100
- self.gpulist=self.backend.gpulist
101
- self.ngpu=self.backend.ngpu
102
- self.rank=mpi_rank
103
-
104
- self.gpupos=(gpupos+mpi_rank)%self.backend.ngpu
105
-
106
- print('============================================================')
107
- print('== ==')
108
- print('== ==')
109
- print('== RUN ON GPU Rank %d : %s =='%(mpi_rank,self.gpulist[self.gpupos%self.ngpu]))
110
- print('== ==')
111
- print('== ==')
112
- print('============================================================')
113
- sys.stdout.flush()
114
-
115
- l_NORIENT=NORIENT
113
+ self.MPI_ALL_TYPE = None
114
+
115
+ self.all_type = all_type
116
+ self.BACKEND = BACKEND
117
+ self.backend = bk.foscat_backend(
118
+ BACKEND,
119
+ all_type=all_type,
120
+ mpi_rank=mpi_rank,
121
+ gpupos=gpupos,
122
+ silent=self.silent,
123
+ )
124
+
125
+ self.all_bk_type = self.backend.all_bk_type
126
+ self.all_cbk_type = self.backend.all_cbk_type
127
+ self.gpulist = self.backend.gpulist
128
+ self.ngpu = self.backend.ngpu
129
+ self.rank = mpi_rank
130
+
131
+ self.gpupos = (gpupos + mpi_rank) % self.backend.ngpu
132
+
133
+ if not self.silent:
134
+ print("============================================================")
135
+ print("== ==")
136
+ print("== ==")
137
+ print(
138
+ "== RUN ON GPU Rank %d : %s =="
139
+ % (mpi_rank, self.gpulist[self.gpupos % self.ngpu])
140
+ )
141
+ print("== ==")
142
+ print("== ==")
143
+ print("============================================================")
144
+ sys.stdout.flush()
145
+
146
+ l_NORIENT = NORIENT
116
147
  if DODIV:
117
- l_NORIENT=NORIENT+2
118
-
119
- self.NORIENT=l_NORIENT
120
- self.LAMBDA=LAMBDA
121
- self.slope=slope
122
-
123
- self.R_off=(KERNELSZ-1)//2
124
- if (self.R_off//2)*2<self.R_off:
125
- self.R_off+=1
126
-
127
- self.ww_Real = {}
128
- self.ww_Imag = {}
129
-
130
- wwc=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
131
- wws=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
132
-
133
- x=np.repeat(np.arange(KERNELSZ)-KERNELSZ//2,KERNELSZ).reshape(KERNELSZ,KERNELSZ)
134
- y=x.T
135
-
136
- if NORIENT==1:
137
- xx=(3/float(KERNELSZ))*LAMBDA*x
138
- yy=(3/float(KERNELSZ))*LAMBDA*y
139
-
140
- if KERNELSZ==5:
141
- #w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
142
- w_smooth=np.exp(-(xx**2+yy**2))
143
- tmp=np.exp(-2*(xx**2+yy**2))-0.25*np.exp(-0.5*(xx**2+yy**2))
148
+ l_NORIENT = NORIENT + 2
149
+
150
+ self.NORIENT = l_NORIENT
151
+ self.LAMBDA = LAMBDA
152
+ self.slope = slope
153
+
154
+ self.R_off = (KERNELSZ - 1) // 2
155
+ if (self.R_off // 2) * 2 < self.R_off:
156
+ self.R_off += 1
157
+
158
+ self.ww_Real = {}
159
+ self.ww_Imag = {}
160
+ self.ww_CNN_Transpose = {}
161
+ self.ww_CNN = {}
162
+ self.X_CNN = {}
163
+ self.Y_CNN = {}
164
+ self.Z_CNN = {}
165
+
166
+ wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
167
+ wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
168
+
169
+ x = np.repeat(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape(
170
+ KERNELSZ, KERNELSZ
171
+ )
172
+ y = x.T
173
+
174
+ if NORIENT == 1:
175
+ xx = (3 / float(KERNELSZ)) * LAMBDA * x
176
+ yy = (3 / float(KERNELSZ)) * LAMBDA * y
177
+
178
+ if KERNELSZ == 5:
179
+ # w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
180
+ w_smooth = np.exp(-(xx**2 + yy**2))
181
+ tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp(
182
+ -0.5 * (xx**2 + yy**2)
183
+ )
144
184
  else:
145
- w_smooth=np.exp(-0.5*(xx**2+yy**2))
146
- tmp=np.exp(-2*(xx**2+yy**2))-0.25*np.exp(-0.5*(xx**2+yy**2))
147
-
148
- wwc[:,0]=tmp.flatten()-tmp.mean()
149
- tmp=0*w_smooth
150
- wws[:,0]=tmp.flatten()
151
- sigma=np.sqrt((wwc[:,0]**2).mean())
152
- wwc[:,0]/=sigma
153
- wws[:,0]/=sigma
154
-
155
- w_smooth=w_smooth.flatten()
185
+ w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
186
+ tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp(
187
+ -0.5 * (xx**2 + yy**2)
188
+ )
189
+
190
+ wwc[:, 0] = tmp.flatten() - tmp.mean()
191
+ tmp = 0 * w_smooth
192
+ wws[:, 0] = tmp.flatten()
193
+ sigma = np.sqrt((wwc[:, 0] ** 2).mean())
194
+ wwc[:, 0] /= sigma
195
+ wws[:, 0] /= sigma
196
+
197
+ w_smooth = w_smooth.flatten()
156
198
  else:
157
199
  for i in range(NORIENT):
158
- a=i/float(NORIENT)*np.pi
159
- xx=(3/float(KERNELSZ))*LAMBDA*(x*np.cos(a)+y*np.sin(a))
160
- yy=(3/float(KERNELSZ))*LAMBDA*(x*np.sin(a)-y*np.cos(a))
200
+ a = i / float(NORIENT) * np.pi
201
+ xx = (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
202
+ yy = (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
161
203
 
162
- if KERNELSZ==5:
163
- #w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
164
- w_smooth=np.exp(-(xx**2+yy**2))
204
+ if KERNELSZ == 5:
205
+ # w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
206
+ w_smooth = np.exp(-(xx**2 + yy**2))
165
207
  else:
166
- w_smooth=np.exp(-0.5*(xx**2+yy**2))
167
- tmp1=np.cos(yy*np.pi)*w_smooth
168
- tmp2=np.sin(yy*np.pi)*w_smooth
169
-
170
- wwc[:,i]=tmp1.flatten()-tmp1.mean()
171
- wws[:,i]=tmp2.flatten()-tmp2.mean()
172
- sigma=np.sqrt((wwc[:,i]**2).mean())
173
- wwc[:,i]/=sigma
174
- wws[:,i]/=sigma
175
-
176
- if DODIV and i==0:
177
- r=(xx**2+yy**2)
178
- theta=np.arctan2(yy,xx)
179
- theta[KERNELSZ//2,KERNELSZ//2]=0.0
180
- tmp1=r*np.cos(2*theta)*w_smooth
181
- tmp2=r*np.sin(2*theta)*w_smooth
182
-
183
- wwc[:,NORIENT]=tmp1.flatten()-tmp1.mean()
184
- wws[:,NORIENT]=tmp2.flatten()-tmp2.mean()
185
- sigma=np.sqrt((wwc[:,NORIENT]**2).mean())
186
-
187
- wwc[:,NORIENT]/=sigma
188
- wws[:,NORIENT]/=sigma
189
- tmp1=r*np.cos(2*theta+np.pi)
190
- tmp2=r*np.sin(2*theta+np.pi)
191
-
192
- wwc[:,NORIENT+1]=tmp1.flatten()-tmp1.mean()
193
- wws[:,NORIENT+1]=tmp2.flatten()-tmp2.mean()
194
- sigma=np.sqrt((wwc[:,NORIENT+1]**2).mean())
195
- wwc[:,NORIENT+1]/=sigma
196
- wws[:,NORIENT+1]/=sigma
197
-
198
-
199
- w_smooth=w_smooth.flatten()
200
-
201
- self.KERNELSZ=KERNELSZ
202
-
203
- self.Idx_Neighbours={}
204
-
205
- if not self.use_2D:
208
+ w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
209
+ tmp1 = np.cos(yy * np.pi) * w_smooth
210
+ tmp2 = np.sin(yy * np.pi) * w_smooth
211
+
212
+ wwc[:, i] = tmp1.flatten() - tmp1.mean()
213
+ wws[:, i] = tmp2.flatten() - tmp2.mean()
214
+ sigma = np.sqrt((wwc[:, i] ** 2).mean())
215
+ wwc[:, i] /= sigma
216
+ wws[:, i] /= sigma
217
+
218
+ if DODIV and i == 0:
219
+ r = xx**2 + yy**2
220
+ theta = np.arctan2(yy, xx)
221
+ theta[KERNELSZ // 2, KERNELSZ // 2] = 0.0
222
+ tmp1 = r * np.cos(2 * theta) * w_smooth
223
+ tmp2 = r * np.sin(2 * theta) * w_smooth
224
+
225
+ wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
226
+ wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
227
+ sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
228
+
229
+ wwc[:, NORIENT] /= sigma
230
+ wws[:, NORIENT] /= sigma
231
+ tmp1 = r * np.cos(2 * theta + np.pi)
232
+ tmp2 = r * np.sin(2 * theta + np.pi)
233
+
234
+ wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
235
+ wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
236
+ sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
237
+ wwc[:, NORIENT + 1] /= sigma
238
+ wws[:, NORIENT + 1] /= sigma
239
+
240
+ w_smooth = w_smooth.flatten()
241
+ if self.use_1D:
242
+ KERNELSZ = 5
243
+
244
+ self.KERNELSZ = KERNELSZ
245
+
246
+ self.Idx_Neighbours = {}
247
+
248
+ if not self.use_2D and not self.use_1D:
206
249
  self.w_smooth = {}
207
250
  for i in range(nstep_max):
208
- lout=(2**i)
209
- self.ww_Real[lout]=None
251
+ lout = 2**i
252
+ self.ww_Real[lout] = None
253
+
254
+ for i in range(1, 6):
255
+ lout = 2**i
256
+ if not self.silent:
257
+ print("Init Wave ", lout)
210
258
 
211
- for i in range(1,6):
212
- lout=(2**i)
213
- print('Init Wave ',lout)
214
-
215
259
  if self.InitWave is None:
216
- wr,wi,ws,widx=self.init_index(lout)
260
+ wr, wi, ws, widx = self.init_index(lout)
217
261
  else:
218
- wr,wi,ws,widx=self.InitWave(self,lout)
219
-
220
- self.Idx_Neighbours[lout]=1 #self.backend.constant(widx)
221
- self.ww_Real[lout]=wr
222
- self.ww_Imag[lout]=wi
223
- self.w_smooth[lout]=ws
224
- else:
225
- self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
226
- self.ww_RealT={}
227
- self.ww_ImagT={}
228
- self.ww_SmoothT={}
262
+ wr, wi, ws, widx = self.InitWave(self, lout)
263
+
264
+ self.Idx_Neighbours[lout] = 1 # self.backend.constant(widx)
265
+ self.ww_Real[lout] = wr
266
+ self.ww_Imag[lout] = wi
267
+ self.w_smooth[lout] = ws
268
+ elif self.use_1D:
269
+ self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
270
+ self.ww_RealT = {}
271
+ self.ww_ImagT = {}
272
+ self.ww_SmoothT = {}
273
+ if KERNELSZ == 5:
274
+ xx = np.arange(5) - 2
275
+ w = np.exp(-0.25 * (xx) ** 2)
276
+ c = w * np.cos((xx) * np.pi / 2)
277
+ s = w * np.sin((xx) * np.pi / 2)
278
+
279
+ w = w / np.sum(w)
280
+ c = c - np.mean(c)
281
+ s = s - np.mean(s)
282
+ r = np.sum(np.sqrt(c * c + s * s))
283
+ c = c / r
284
+ s = s / r
285
+ self.ww_RealT[1] = self.backend.constant(
286
+ np.array(c).reshape(xx.shape[0], 1, 1)
287
+ )
288
+ self.ww_ImagT[1] = self.backend.constant(
289
+ np.array(s).reshape(xx.shape[0], 1, 1)
290
+ )
291
+ self.ww_SmoothT[1] = self.backend.constant(
292
+ np.array(w).reshape(xx.shape[0], 1, 1)
293
+ )
229
294
 
230
- self.ww_SmoothT[1] = self.backend.constant(self.w_smooth.reshape(KERNELSZ,KERNELSZ,1,1))
231
- www=np.zeros([KERNELSZ,KERNELSZ,NORIENT,NORIENT],dtype=self.all_type)
295
+ else:
296
+ self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
297
+ self.ww_RealT = {}
298
+ self.ww_ImagT = {}
299
+ self.ww_SmoothT = {}
300
+
301
+ self.ww_SmoothT[1] = self.backend.constant(
302
+ self.w_smooth.reshape(KERNELSZ, KERNELSZ, 1, 1)
303
+ )
304
+ www = np.zeros([KERNELSZ, KERNELSZ, NORIENT, NORIENT], dtype=self.all_type)
232
305
  for k in range(NORIENT):
233
- www[:,:,k,k]=self.w_smooth.reshape(KERNELSZ,KERNELSZ)
234
- self.ww_SmoothT[NORIENT] = self.backend.constant(www.reshape(KERNELSZ,KERNELSZ,NORIENT,NORIENT))
235
- self.ww_RealT[1]=self.backend.constant(self.backend.bk_reshape(wwc.astype(self.all_type),[KERNELSZ,KERNELSZ,1,NORIENT]))
236
- self.ww_ImagT[1]=self.backend.constant(self.backend.bk_reshape(wws.astype(self.all_type),[KERNELSZ,KERNELSZ,1,NORIENT]))
306
+ www[:, :, k, k] = self.w_smooth.reshape(KERNELSZ, KERNELSZ)
307
+ self.ww_SmoothT[NORIENT] = self.backend.constant(
308
+ www.reshape(KERNELSZ, KERNELSZ, NORIENT, NORIENT)
309
+ )
310
+ self.ww_RealT[1] = self.backend.constant(
311
+ self.backend.bk_reshape(
312
+ wwc.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
313
+ )
314
+ )
315
+ self.ww_ImagT[1] = self.backend.constant(
316
+ self.backend.bk_reshape(
317
+ wws.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
318
+ )
319
+ )
320
+
237
321
  def doorientw(x):
238
- y=np.zeros([KERNELSZ,KERNELSZ,NORIENT,NORIENT*NORIENT],dtype=self.all_type)
322
+ y = np.zeros(
323
+ [KERNELSZ, KERNELSZ, NORIENT, NORIENT * NORIENT],
324
+ dtype=self.all_type,
325
+ )
239
326
  for k in range(NORIENT):
240
- y[:,:,k,k*NORIENT:k*NORIENT+NORIENT]=x.reshape(KERNELSZ,KERNELSZ,NORIENT)
327
+ y[:, :, k, k * NORIENT : k * NORIENT + NORIENT] = x.reshape(
328
+ KERNELSZ, KERNELSZ, NORIENT
329
+ )
241
330
  return y
242
- self.ww_RealT[NORIENT]=self.backend.constant(doorientw(wwc.astype(self.all_type)))
243
- self.ww_ImagT[NORIENT]=self.backend.constant(doorientw(wws.astype(self.all_type)))
244
- self.pix_interp_val={}
245
- self.weight_interp_val={}
246
- self.ring2nest={}
247
- self.nest2R={}
248
- self.nest2R1={}
249
- self.nest2R2={}
250
- self.nest2R3={}
251
- self.nest2R4={}
252
- self.inv_nest2R={}
253
- self.remove_border={}
254
-
255
- self.ampnorm={}
256
-
331
+
332
+ self.ww_RealT[NORIENT] = self.backend.constant(
333
+ doorientw(wwc.astype(self.all_type))
334
+ )
335
+ self.ww_ImagT[NORIENT] = self.backend.constant(
336
+ doorientw(wws.astype(self.all_type))
337
+ )
338
+ self.pix_interp_val = {}
339
+ self.weight_interp_val = {}
340
+ self.ring2nest = {}
341
+ self.nest2R = {}
342
+ self.nest2R1 = {}
343
+ self.nest2R2 = {}
344
+ self.nest2R3 = {}
345
+ self.nest2R4 = {}
346
+ self.inv_nest2R = {}
347
+ self.remove_border = {}
348
+
349
+ self.ampnorm = {}
350
+
257
351
  for i in range(nstep_max):
258
- lout=(2**i)
259
- self.pix_interp_val[lout]={}
260
- self.weight_interp_val[lout]={}
352
+ lout = 2**i
353
+ self.pix_interp_val[lout] = {}
354
+ self.weight_interp_val[lout] = {}
261
355
  for j in range(nstep_max):
262
- lout2=(2**j)
263
- self.pix_interp_val[lout][lout2]=None
264
- self.weight_interp_val[lout][lout2]=None
265
- self.ring2nest[lout]=None
266
- self.Idx_Neighbours[lout]=None
267
- self.nest2R[lout]=None
268
- self.nest2R1[lout]=None
269
- self.nest2R2[lout]=None
270
- self.nest2R3[lout]=None
271
- self.nest2R4[lout]=None
272
- self.inv_nest2R[lout]=None
273
- self.remove_border[lout]=None
274
-
275
- self.loss={}
356
+ lout2 = 2**j
357
+ self.pix_interp_val[lout][lout2] = None
358
+ self.weight_interp_val[lout][lout2] = None
359
+ self.ring2nest[lout] = None
360
+ self.Idx_Neighbours[lout] = None
361
+ self.nest2R[lout] = None
362
+ self.nest2R1[lout] = None
363
+ self.nest2R2[lout] = None
364
+ self.nest2R3[lout] = None
365
+ self.nest2R4[lout] = None
366
+ self.inv_nest2R[lout] = None
367
+ self.remove_border[lout] = None
368
+ self.ww_CNN_Transpose[lout] = None
369
+ self.ww_CNN[lout] = None
370
+ self.X_CNN[lout] = None
371
+ self.Y_CNN[lout] = None
372
+ self.Z_CNN[lout] = None
373
+
374
+ self.loss = {}
276
375
 
277
376
  def get_type(self):
278
377
  return self.all_type
279
378
 
280
379
  def get_mpi_type(self):
281
380
  return self.MPI_ALL_TYPE
282
-
381
+
283
382
  # ---------------------------------------------−---------
284
383
  # -- COMPUTE 3X3 INDEX FOR HEALPIX WORK --
285
384
  # ---------------------------------------------−---------
286
- def conv_to_FoCUS(self,x,axis=0):
287
- if self.use_2D and isinstance(x,np.ndarray):
288
- return(self.to_R(x,axis,chans=self.chans))
385
+ def conv_to_FoCUS(self, x, axis=0):
386
+ if self.use_2D and isinstance(x, np.ndarray):
387
+ return self.to_R(x, axis, chans=self.chans)
289
388
  return x
290
389
 
291
- def diffang(self,a,b):
292
- return np.arctan2(np.sin(a)-np.sin(b),np.cos(a)-np.cos(b))
293
-
294
- def corr_idx_wXX(self,x,y):
295
- idx=np.where(x==-1)[0]
296
- res=x
297
- res[idx]=y[idx]
298
- return(res)
299
-
390
+ def diffang(self, a, b):
391
+ return np.arctan2(np.sin(a) - np.sin(b), np.cos(a) - np.cos(b))
392
+
393
+ def corr_idx_wXX(self, x, y):
394
+ idx = np.where(x == -1)[0]
395
+ res = x
396
+ res[idx] = y[idx]
397
+ return res
398
+
399
+ # ---------------------------------------------−---------
400
+ # make the CNN working : index reporjection of the kernel on healpix
401
+
402
+ def calc_indices_convol(self, nside, kernel, rotation=None):
403
+ to, po = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
404
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside * nside), nest=True)
405
+
406
+ idx = np.argsort((x - 1.0) ** 2 + y**2 + z**2)[0:kernel]
407
+ x0, y0, z0 = hp.pix2vec(nside, idx[0], nest=True)
408
+ t0, p0 = hp.pix2ang(nside, idx[0], nest=True)
409
+
410
+ idx = np.argsort((x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2)[0:kernel]
411
+ im = np.ones([12 * nside**2]) * -1
412
+ im[idx] = np.arange(len(idx))
413
+
414
+ xc, yc, zc = hp.pix2vec(nside, idx, nest=True)
415
+
416
+ xc -= x0
417
+ yc -= y0
418
+ zc -= z0
419
+
420
+ vec = np.concatenate(
421
+ [np.expand_dims(x, -1), np.expand_dims(y, -1), np.expand_dims(z, -1)], 1
422
+ )
423
+
424
+ indices = np.zeros([12 * nside**2 * 250, 2], dtype="int")
425
+ weights = np.zeros([12 * nside**2 * 250])
426
+ nn = 0
427
+ for k in range(12 * nside * nside):
428
+ if k % (nside * nside) == nside * nside - 1:
429
+ print(
430
+ "Nside=%d KenelSZ=%d %.2f%%"
431
+ % (nside, kernel, k / (12 * nside**2) * 100)
432
+ )
433
+ if nside < 4:
434
+ idx2 = np.arange(12 * nside**2)
435
+ else:
436
+ idx2 = hp.query_disc(
437
+ nside, vec[k], np.pi / nside, inclusive=True, nest=True
438
+ )
439
+ t2, p2 = hp.pix2ang(nside, idx2, nest=True)
440
+ if rotation is None:
441
+ rot = [po[k] / np.pi * 180.0, (t0 - to[k]) / np.pi * 180.0]
442
+ else:
443
+ rot = [po[k] / np.pi * 180.0, (t0 - to[k]) / np.pi * 180.0, rotation[k]]
444
+
445
+ r = hp.Rotator(rot=rot)
446
+ t2, p2 = r(t2, p2)
447
+
448
+ ii, ww = hp.get_interp_weights(nside, t2, p2, nest=True)
449
+
450
+ ii = im[ii]
451
+
452
+ for l_rotation in range(4):
453
+ iii = np.where(ii[l_rotation] != -1)[0]
454
+ if len(iii) > 0:
455
+ indices[nn : nn + len(iii), 1] = idx2[iii]
456
+ indices[nn : nn + len(iii), 0] = k * kernel + ii[l_rotation, iii]
457
+ weights[nn : nn + len(iii)] = ww[l_rotation, iii]
458
+ nn += len(iii)
459
+
460
+ indices = indices[0:nn]
461
+ weights = weights[0:nn]
462
+ if k % (nside * nside) == nside * nside - 1:
463
+ print(
464
+ "Nside=%d KenelSZ=%d Total Number of value=%d Ratio of the matrix %.2g%%"
465
+ % (
466
+ nside,
467
+ kernel,
468
+ nn,
469
+ 100 * nn / (kernel * 12 * nside**2 * 12 * nside**2),
470
+ )
471
+ )
472
+ return indices, weights, xc, yc, zc
473
+
474
+ # ---------------------------------------------−---------
475
+ def calc_orientation(self, im): # im is [Ndata,12*Nside**2]
476
+ nside = int(np.sqrt(im.shape[1] // 12))
477
+ l_kernel = self.KERNELSZ * self.KERNELSZ
478
+ norient = 32
479
+ w = np.zeros([l_kernel, 1, 2 * norient])
480
+ ca = np.cos(np.arange(norient) / norient * np.pi)
481
+ sa = np.sin(np.arange(norient) / norient * np.pi)
482
+ stat = np.zeros([12 * nside**2, norient])
483
+
484
+ if self.ww_CNN[nside] is None:
485
+ self.init_CNN_index(nside, transpose=False)
486
+
487
+ y = self.Y_CNN[nside]
488
+ z = self.Z_CNN[nside]
489
+
490
+ for k in range(norient):
491
+ w[:, 0, k] = np.exp(-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)) * np.cos(
492
+ nside * (y * ca[k] + z * sa[k]) * np.pi / 2
493
+ )
494
+ w[:, 0, k + norient] = np.exp(
495
+ -0.5 * nside**2 * ((y) ** 2 + (z) ** 2)
496
+ ) * np.sin(nside * (y * ca[k] + z * sa[k]) * np.pi / 2)
497
+ w[:, 0, k] = w[:, 0, k] - np.mean(w[:, 0, k])
498
+ w[:, 0, k + norient] = w[:, 0, k] - np.mean(w[:, 0, k + norient])
499
+
500
+ for k in range(im.shape[0]):
501
+ tmp = im[k].reshape(12 * nside**2, 1)
502
+ im2 = self.healpix_layer(tmp, w)
503
+ stat = stat + im2[:, 0:norient] ** 2 + im2[:, norient:] ** 2
504
+
505
+ rotation = (np.argmax(stat, 1)).astype("float") / 32.0 * 180.0
506
+
507
+ indices, weights, x, y, z = self.calc_indices_convol(
508
+ nside, 9, rotation=rotation
509
+ )
510
+
511
+ return indices, weights
512
+
513
+ def init_CNN_index(self, nside, transpose=False):
514
+ l_kernel = int(self.KERNELSZ * self.KERNELSZ)
515
+ try:
516
+ indices = np.load(
517
+ "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
518
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
519
+ )
520
+ weights = np.load(
521
+ "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
522
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
523
+ )
524
+ xc = np.load(
525
+ "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
526
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
527
+ )
528
+ yc = np.load(
529
+ "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
530
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
531
+ )
532
+ zc = np.load(
533
+ "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
534
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
535
+ )
536
+ except:
537
+ indices, weights, xc, yc, zc = self.calc_indices_convol(nside, l_kernel)
538
+ np.save(
539
+ "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
540
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
541
+ indices,
542
+ )
543
+ np.save(
544
+ "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
545
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
546
+ weights,
547
+ )
548
+ np.save(
549
+ "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
550
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
551
+ xc,
552
+ )
553
+ np.save(
554
+ "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
555
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
556
+ yc,
557
+ )
558
+ np.save(
559
+ "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
560
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
561
+ zc,
562
+ )
563
+ if not self.silent:
564
+ print(
565
+ "Write %s/FOSCAT_%s_W%d_%d_%d_CNNV2.npy"
566
+ % (
567
+ self.TEMPLATE_PATH,
568
+ TMPFILE_VERSION,
569
+ l_kernel,
570
+ self.NORIENT,
571
+ nside,
572
+ )
573
+ )
574
+
575
+ self.X_CNN[nside] = xc
576
+ self.Y_CNN[nside] = yc
577
+ self.Z_CNN[nside] = zc
578
+ self.ww_CNN[nside] = self.backend.bk_SparseTensor(
579
+ indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
580
+ )
581
+
582
+ # ---------------------------------------------−---------
583
+ def healpix_layer_coord(self, im, axis=0):
584
+ nside = int(np.sqrt(im.shape[axis] // 12))
585
+ if self.ww_CNN[nside] is None:
586
+ self.init_CNN_index(nside)
587
+ return self.X_CNN[nside], self.Y_CNN[nside], self.Z_CNN[nside]
588
+
589
+ # ---------------------------------------------−---------
590
+ def healpix_layer_transpose(self, im, ww, indices=None, weights=None, axis=0):
591
+ nside = int(np.sqrt(im.shape[axis] // 12))
592
+
593
+ if im.shape[1 + axis] != ww.shape[1]:
594
+ if not self.silent:
595
+ print("Weights channels should be equal to the input image channels")
596
+ return -1
597
+ if axis == 1:
598
+ results = []
599
+
600
+ for k in range(im.shape[0]):
601
+
602
+ tmp = self.healpix_layer(
603
+ im[k], ww, indices=indices, weights=weights, axis=0
604
+ )
605
+ tmp = self.backend.bk_reshape(
606
+ self.up_grade(tmp, 2 * nside), [12 * 4 * nside**2, ww.shape[2]]
607
+ )
608
+
609
+ results.append(tmp)
610
+
611
+ return self.backend.bk_stack(results, axis=0)
612
+ else:
613
+ tmp = self.healpix_layer(
614
+ im, ww, indices=indices, weights=weights, axis=axis
615
+ )
616
+
617
+ return self.up_grade(tmp, 2 * nside)
618
+
619
+ # ---------------------------------------------−---------
620
+ # ---------------------------------------------−---------
621
+ def healpix_layer(self, im, ww, indices=None, weights=None, axis=0):
622
+ nside = int(np.sqrt(im.shape[axis] // 12))
623
+ l_kernel = self.KERNELSZ * self.KERNELSZ
624
+
625
+ if im.shape[1 + axis] != ww.shape[1]:
626
+ if not self.silent:
627
+ print("Weights channels should be equal to the input image channels")
628
+ return -1
629
+
630
+ if indices is None:
631
+ if self.ww_CNN[nside] is None:
632
+ self.init_CNN_index(nside, transpose=False)
633
+ mat = self.ww_CNN[nside]
634
+ else:
635
+ if weights is None:
636
+ print(
637
+ "healpix_layer : If indices is not none weights should be specify"
638
+ )
639
+ return 0
640
+
641
+ mat = self.backend.bk_SparseTensor(
642
+ indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
643
+ )
644
+
645
+ if axis == 1:
646
+ results = []
647
+
648
+ for k in range(im.shape[0]):
649
+
650
+ tmp = self.backend.bk_sparse_dense_matmul(mat, im[k])
651
+
652
+ density = self.backend.bk_reshape(
653
+ tmp, [12 * nside * nside, l_kernel * im.shape[1 + axis]]
654
+ )
655
+
656
+ density = self.backend.bk_matmul(
657
+ density,
658
+ self.backend.bk_reshape(
659
+ ww, [l_kernel * im.shape[1 + axis], ww.shape[2]]
660
+ ),
661
+ )
662
+
663
+ results.append(
664
+ self.backend.bk_reshape(density, [12 * nside**2, ww.shape[2]])
665
+ )
666
+
667
+ return self.backend.bk_stack(results, axis=0)
668
+ else:
669
+ tmp = self.backend.bk_sparse_dense_matmul(mat, im)
670
+
671
+ density = self.backend.bk_reshape(
672
+ tmp, [12 * nside * nside, l_kernel * im.shape[1]]
673
+ )
674
+
675
+ return self.backend.bk_matmul(
676
+ density,
677
+ self.backend.bk_reshape(ww, [l_kernel * im.shape[1], ww.shape[2]]),
678
+ )
679
+
680
+ # ---------------------------------------------−---------
681
+
300
682
  # ---------------------------------------------−---------
301
683
  def get_rank(self):
302
- return(self.rank)
684
+ return self.rank
685
+
303
686
  # ---------------------------------------------−---------
304
687
  def get_size(self):
305
- return(self.size)
306
-
688
+ return self.size
689
+
307
690
  # ---------------------------------------------−---------
308
691
  def barrier(self):
309
692
  if self.isMPI:
310
693
  self.comm.Barrier()
311
-
694
+
312
695
  # ---------------------------------------------−---------
313
- def toring(self,image,axis=0):
314
- lout=int(np.sqrt(image.shape[axis]//12))
315
-
696
+ def toring(self, image, axis=0):
697
+ lout = int(np.sqrt(image.shape[axis] // 12))
698
+
316
699
  if self.ring2nest[lout] is None:
317
- self.ring2nest[lout]=hp.ring2nest(lout,np.arange(12*lout**2))
318
-
700
+ self.ring2nest[lout] = hp.ring2nest(lout, np.arange(12 * lout**2))
701
+
319
702
  return image.numpy()[self.ring2nest[lout]]
320
703
 
321
- #--------------------------------------------------------
322
- def ud_grade(self,im,j,axis=0):
323
- rim=im
704
+ # --------------------------------------------------------
705
+ def ud_grade(self, im, j, axis=0):
706
+ rim = im
324
707
  for k in range(j):
325
- rim=self.smooth(rim,axis=axis)
326
- rim=self.ud_grade_2(rim,axis=axis)
708
+ rim = self.smooth(rim, axis=axis)
709
+ rim = self.ud_grade_2(rim, axis=axis)
327
710
  return rim
328
-
329
- #--------------------------------------------------------
330
- def ud_grade_2(self,im,axis=0):
331
-
711
+
712
+ # --------------------------------------------------------
713
+ def ud_grade_2(self, im, axis=0):
714
+
332
715
  if self.use_2D:
333
- ishape=list(im.shape)
334
- if len(ishape)<axis+2:
335
- print('Use of 2D scat with data that has less than 2D')
336
- exit(0)
337
-
338
- npix=im.shape[axis]
339
- npiy=im.shape[axis+1]
340
- odata=1
341
- if len(ishape)>axis+2:
342
- for k in range(axis+2,len(ishape)):
343
- odata=odata*ishape[k]
344
-
345
- ndata=1
716
+ ishape = list(im.shape)
717
+ if len(ishape) < axis + 2:
718
+ if not self.silent:
719
+ print("Use of 2D scat with data that has less than 2D")
720
+ return None
721
+
722
+ npix = im.shape[axis]
723
+ npiy = im.shape[axis + 1]
724
+ odata = 1
725
+ if len(ishape) > axis + 2:
726
+ for k in range(axis + 2, len(ishape)):
727
+ odata = odata * ishape[k]
728
+
729
+ ndata = 1
346
730
  for k in range(axis):
347
- ndata=ndata*ishape[k]
731
+ ndata = ndata * ishape[k]
732
+
733
+ tim = self.backend.bk_reshape(
734
+ self.backend.bk_cast(im), [ndata, npix, npiy, odata]
735
+ )
736
+ tim = self.backend.bk_reshape(
737
+ tim[:, 0 : 2 * (npix // 2), 0 : 2 * (npiy // 2), :],
738
+ [ndata, npix // 2, 2, npiy // 2, 2, odata],
739
+ )
740
+
741
+ res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
742
+
743
+ if axis == 0:
744
+ if len(ishape) == 2:
745
+ return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
746
+ else:
747
+ return self.backend.bk_reshape(
748
+ res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
749
+ )
750
+ else:
751
+ if len(ishape) == axis + 2:
752
+ return self.backend.bk_reshape(
753
+ res, ishape[0:axis] + [npix // 2, npiy // 2]
754
+ )
755
+ else:
756
+ return self.backend.bk_reshape(
757
+ res,
758
+ ishape[0:axis] + [npix // 2, npiy // 2] + ishape[axis + 2 :],
759
+ )
760
+
761
+ return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
762
+ elif self.use_1D:
763
+ ishape = list(im.shape)
764
+ if len(ishape) < axis + 1:
765
+ if not self.silent:
766
+ print("Use of 1D scat with data that has less than 1D")
767
+ return None
768
+
769
+ npix = im.shape[axis]
770
+ odata = 1
771
+ if len(ishape) > axis + 1:
772
+ for k in range(axis + 1, len(ishape)):
773
+ odata = odata * ishape[k]
774
+
775
+ ndata = 1
776
+ for k in range(axis):
777
+ ndata = ndata * ishape[k]
778
+
779
+ tim = self.backend.bk_reshape(
780
+ self.backend.bk_cast(im), [ndata, npix, odata]
781
+ )
782
+ tim = self.backend.bk_reshape(
783
+ tim[:, 0 : 2 * (npix // 2), :], [ndata, npix // 2, 2, odata]
784
+ )
348
785
 
349
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
350
- tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),0:2*(npiy//2),:],[ndata,npix//2,2,npiy//2,2,odata])
786
+ res = self.backend.bk_reduce_mean(tim, 2)
351
787
 
352
- res=self.backend.bk_reduce_mean(self.backend.bk_reduce_mean(tim,4),2)
353
-
354
- if axis==0:
355
- if len(ishape)==2:
356
- return self.backend.bk_reshape(res,[npix//2,npiy//2])
788
+ if axis == 0:
789
+ if len(ishape) == 1:
790
+ return self.backend.bk_reshape(res, [npix // 2])
357
791
  else:
358
- return self.backend.bk_reshape(res,[npix//2,npiy//2]+ishape[axis+2:])
792
+ return self.backend.bk_reshape(
793
+ res, [npix // 2] + ishape[axis + 1 :]
794
+ )
359
795
  else:
360
- if len(ishape)==axis+2:
361
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2])
796
+ if len(ishape) == axis + 1:
797
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2])
362
798
  else:
363
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2]+ishape[axis+2:])
364
-
365
- return self.backend.bk_reshape(res,[npix//2,npiy//2])
366
-
799
+ return self.backend.bk_reshape(
800
+ res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
801
+ )
802
+
803
+ return self.backend.bk_reshape(res, [npix // 2])
804
+
367
805
  else:
368
- shape=list(im.shape)
369
-
370
- lout=int(np.sqrt(shape[axis]//12))
371
- if im.__class__==np.zeros([0]).__class__:
372
- oshape=np.zeros([len(shape)+1],dtype='int')
373
- if axis>0:
374
- oshape[0:axis]=shape[0:axis]
375
- oshape[axis]=12*lout*lout//4
376
- oshape[axis+1]=4
377
- if len(shape)>axis:
378
- oshape[axis+2:]=shape[axis+1:]
806
+ shape = list(im.shape)
807
+
808
+ lout = int(np.sqrt(shape[axis] // 12))
809
+ if im.__class__ == np.zeros([0]).__class__:
810
+ oshape = np.zeros([len(shape) + 1], dtype="int")
811
+ if axis > 0:
812
+ oshape[0:axis] = shape[0:axis]
813
+ oshape[axis] = 12 * lout * lout // 4
814
+ oshape[axis + 1] = 4
815
+ if len(shape) > axis:
816
+ oshape[axis + 2 :] = shape[axis + 1 :]
379
817
  else:
380
- if axis>0:
381
- oshape=shape[0:axis]+[12*lout*lout//4,4]
818
+ if axis > 0:
819
+ oshape = shape[0:axis] + [12 * lout * lout // 4, 4]
382
820
  else:
383
- oshape=[12*lout*lout//4,4]
384
- if len(shape)>axis:
385
- oshape=oshape+shape[axis+1:]
386
-
387
- return(self.backend.bk_reduce_mean(self.backend.bk_reshape(im,oshape),axis=axis+1))
388
-
389
- #--------------------------------------------------------
390
- def up_grade(self,im,nout,axis=0,nouty=None):
391
-
821
+ oshape = [12 * lout * lout // 4, 4]
822
+ if len(shape) > axis:
823
+ oshape = oshape + shape[axis + 1 :]
824
+
825
+ return self.backend.bk_reduce_mean(
826
+ self.backend.bk_reshape(im, oshape), axis=axis + 1
827
+ )
828
+
829
+ # --------------------------------------------------------
830
+ def up_grade(self, im, nout, axis=0, nouty=None):
831
+
392
832
  if self.use_2D:
393
- ishape=list(im.shape)
394
- if len(ishape)<axis+2:
395
- print('Use of 2D scat with data that has less than 2D')
396
- exit(0)
397
-
833
+ ishape = list(im.shape)
834
+ if len(ishape) < axis + 2:
835
+ if not self.silent:
836
+ print("Use of 2D scat with data that has less than 2D")
837
+ return None
838
+
398
839
  if nouty is None:
399
- nouty=nout
400
-
401
- if ishape[axis]==nout and ishape[axis+1]==nouty:
840
+ nouty = nout
841
+
842
+ if ishape[axis] == nout and ishape[axis + 1] == nouty:
402
843
  return im
403
-
404
- npix=im.shape[axis]
405
- npiy=im.shape[axis+1]
406
- odata=1
407
- if len(ishape)>axis+2:
408
- for k in range(axis+2,len(ishape)):
409
- odata=odata*ishape[k]
410
-
411
- ndata=1
844
+
845
+ npix = im.shape[axis]
846
+ npiy = im.shape[axis + 1]
847
+ odata = 1
848
+ if len(ishape) > axis + 2:
849
+ for k in range(axis + 2, len(ishape)):
850
+ odata = odata * ishape[k]
851
+
852
+ ndata = 1
412
853
  for k in range(axis):
413
- ndata=ndata*ishape[k]
854
+ ndata = ndata * ishape[k]
855
+
856
+ tim = self.backend.bk_reshape(
857
+ self.backend.bk_cast(im), [ndata, npix, npiy, odata]
858
+ )
859
+
860
+ res = self.backend.bk_resize_image(tim, [nout, nouty])
861
+
862
+ if axis == 0:
863
+ if len(ishape) == 2:
864
+ return self.backend.bk_reshape(res, [nout, nouty])
865
+ else:
866
+ return self.backend.bk_reshape(
867
+ res, [nout, nouty] + ishape[axis + 2 :]
868
+ )
869
+ else:
870
+ if len(ishape) == axis + 2:
871
+ return self.backend.bk_reshape(res, ishape[0:axis] + [nout, nouty])
872
+ else:
873
+ return self.backend.bk_reshape(
874
+ res, ishape[0:axis] + [nout, nouty] + ishape[axis + 2 :]
875
+ )
876
+
877
+ return self.backend.bk_reshape(res, [nout, nouty])
414
878
 
415
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
879
+ elif self.use_1D:
880
+ ishape = list(im.shape)
881
+ if len(ishape) < axis + 1:
882
+ if not self.silent:
883
+ print("Use of 1D scat with data that has less than 1D")
884
+ return None
416
885
 
417
- res=self.backend.bk_resize_image(tim,[nout,nouty])
418
-
419
- if axis==0:
420
- if len(ishape)==2:
421
- return self.backend.bk_reshape(res,[nout,nouty])
886
+ if ishape[axis] == nout:
887
+ return im
888
+
889
+ npix = im.shape[axis]
890
+ odata = 1
891
+ if len(ishape) > axis + 1:
892
+ for k in range(axis + 1, len(ishape)):
893
+ odata = odata * ishape[k]
894
+
895
+ ndata = 1
896
+ for k in range(axis):
897
+ ndata = ndata * ishape[k]
898
+
899
+ tim = self.backend.bk_reshape(
900
+ self.backend.bk_cast(im), [ndata, npix, odata]
901
+ )
902
+
903
+ while tim.shape[1] != nout:
904
+ res2 = self.backend.bk_expand_dims(
905
+ self.backend.bk_concat(
906
+ [(tim[:, 1:, :] + 3 * tim[:, :-1, :]) / 4, tim[:, -1:, :]], 1
907
+ ),
908
+ -2,
909
+ )
910
+ res1 = self.backend.bk_expand_dims(
911
+ self.backend.bk_concat(
912
+ [tim[:, 0:1, :], (tim[:, 1:, :] * 3 + tim[:, :-1, :]) / 4], 1
913
+ ),
914
+ -2,
915
+ )
916
+ tim = self.backend.bk_reshape(
917
+ self.backend.bk_concat([res1, res2], -2),
918
+ [ndata, tim.shape[1] * 2, odata],
919
+ )
920
+
921
+ if axis == 0:
922
+ if len(ishape) == 1:
923
+ return self.backend.bk_reshape(tim, [nout])
422
924
  else:
423
- return self.backend.bk_reshape(res,[nout,nouty]+ishape[axis+2:])
925
+ return self.backend.bk_reshape(tim, [nout] + ishape[axis + 1 :])
424
926
  else:
425
- if len(ishape)==axis+2:
426
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty])
927
+ if len(ishape) == axis + 1:
928
+ return self.backend.bk_reshape(tim, ishape[0:axis] + [nout])
427
929
  else:
428
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty]+ishape[axis+2:])
429
-
430
- return self.backend.bk_reshape(res,[nout,nouty])
431
-
930
+ return self.backend.bk_reshape(
931
+ tim, ishape[0:axis] + [nout] + ishape[axis + 1 :]
932
+ )
933
+
934
+ return self.backend.bk_reshape(tim, [nout])
935
+
432
936
  else:
433
937
 
434
- lout=int(np.sqrt(im.shape[axis]//12))
435
-
938
+ lout = int(np.sqrt(im.shape[axis] // 12))
939
+
436
940
  if self.pix_interp_val[lout][nout] is None:
437
- print('compute lout nout',lout,nout)
438
- th,ph=hp.pix2ang(nout,np.arange(12*nout**2,dtype='int'),nest=True)
439
- p, w = hp.get_interp_weights(lout,th,ph,nest=True)
941
+ if not self.silent:
942
+ print("compute lout nout", lout, nout)
943
+ th, ph = hp.pix2ang(
944
+ nout, np.arange(12 * nout**2, dtype="int"), nest=True
945
+ )
946
+ p, w = hp.get_interp_weights(lout, th, ph, nest=True)
440
947
  del th
441
948
  del ph
442
-
443
- indice=np.zeros([12*nout*nout*4,2],dtype='int')
444
- p=p.T
445
- w=w.T
446
- t=np.argsort(p,1).flatten() # to make oder indices for sparsematrix computation
447
- t=(t+np.repeat(np.arange(12*nout*nout)*4,4))
448
- p=p.flatten()[t]
449
- w=w.flatten()[t]
450
- indice[:,0]=np.repeat(np.arange(12*nout**2),4)
451
- indice[:,1]=p
452
-
453
- self.pix_interp_val[lout][nout]=1
454
- self.weight_interp_val[lout][nout] = self.backend.bk_SparseTensor(self.backend.constant(indice), \
455
- self.backend.constant(self.backend.bk_cast(w.flatten())), \
456
- dense_shape=[12*nout**2,12*lout**2])
457
-
458
- if lout==nout:
459
- imout=im
949
+
950
+ indice = np.zeros([12 * nout * nout * 4, 2], dtype="int")
951
+ p = p.T
952
+ w = w.T
953
+ t = np.argsort(
954
+ p, 1
955
+ ).flatten() # to make oder indices for sparsematrix computation
956
+ t = t + np.repeat(np.arange(12 * nout * nout) * 4, 4)
957
+ p = p.flatten()[t]
958
+ w = w.flatten()[t]
959
+ indice[:, 0] = np.repeat(np.arange(12 * nout**2), 4)
960
+ indice[:, 1] = p
961
+
962
+ self.pix_interp_val[lout][nout] = 1
963
+ self.weight_interp_val[lout][nout] = self.backend.bk_SparseTensor(
964
+ self.backend.constant(indice),
965
+ self.backend.constant(self.backend.bk_cast(w.flatten())),
966
+ dense_shape=[12 * nout**2, 12 * lout**2],
967
+ )
968
+
969
+ if lout == nout:
970
+ imout = im
460
971
  else:
461
972
 
462
- ishape=list(im.shape)
463
- odata=1
464
- for k in range(axis+1,len(ishape)):
465
- odata=odata*ishape[k]
466
-
467
- ndata=1
973
+ ishape = list(im.shape)
974
+ odata = 1
975
+ for k in range(axis + 1, len(ishape)):
976
+ odata = odata * ishape[k]
977
+
978
+ ndata = 1
468
979
  for k in range(axis):
469
- ndata=ndata*ishape[k]
470
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,12*lout**2,odata])
471
- if tim.dtype==self.all_cbk_type:
472
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
473
- ,self.backend.bk_real(tim[0])),[1,12*nout**2,odata])
474
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
475
- ,self.backend.bk_imag(tim[0])),[1,12*nout**2,odata])
476
- imout=self.backend.bk_complex(rr,ii)
980
+ ndata = ndata * ishape[k]
981
+ tim = self.backend.bk_reshape(
982
+ self.backend.bk_cast(im), [ndata, 12 * lout**2, odata]
983
+ )
984
+ if tim.dtype == self.all_cbk_type:
985
+ rr = self.backend.bk_reshape(
986
+ self.backend.bk_sparse_dense_matmul(
987
+ self.weight_interp_val[lout][nout],
988
+ self.backend.bk_real(tim[0]),
989
+ ),
990
+ [1, 12 * nout**2, odata],
991
+ )
992
+ ii = self.backend.bk_reshape(
993
+ self.backend.bk_sparse_dense_matmul(
994
+ self.weight_interp_val[lout][nout],
995
+ self.backend.bk_imag(tim[0]),
996
+ ),
997
+ [1, 12 * nout**2, odata],
998
+ )
999
+ imout = self.backend.bk_complex(rr, ii)
477
1000
  else:
478
- imout=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
479
- ,tim[0]),[1,12*nout**2,odata])
480
-
481
- for k in range(1,ndata):
482
- if tim.dtype==self.all_cbk_type:
483
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
484
- ,self.backend.bk_real(tim[k])),[1,12*nout**2,odata])
485
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
486
- ,self.backend.bk_imag(tim[k])),[1,12*nout**2,odata])
487
- imout=self.backend.bk_concat([imout,self.backend.bk_complex(rr,ii)],0)
1001
+ imout = self.backend.bk_reshape(
1002
+ self.backend.bk_sparse_dense_matmul(
1003
+ self.weight_interp_val[lout][nout], tim[0]
1004
+ ),
1005
+ [1, 12 * nout**2, odata],
1006
+ )
1007
+
1008
+ for k in range(1, ndata):
1009
+ if tim.dtype == self.all_cbk_type:
1010
+ rr = self.backend.bk_reshape(
1011
+ self.backend.bk_sparse_dense_matmul(
1012
+ self.weight_interp_val[lout][nout],
1013
+ self.backend.bk_real(tim[k]),
1014
+ ),
1015
+ [1, 12 * nout**2, odata],
1016
+ )
1017
+ ii = self.backend.bk_reshape(
1018
+ self.backend.bk_sparse_dense_matmul(
1019
+ self.weight_interp_val[lout][nout],
1020
+ self.backend.bk_imag(tim[k]),
1021
+ ),
1022
+ [1, 12 * nout**2, odata],
1023
+ )
1024
+ imout = self.backend.bk_concat(
1025
+ [imout, self.backend.bk_complex(rr, ii)], 0
1026
+ )
488
1027
  else:
489
- imout=self.backend.bk_concat([imout,self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
490
- ,tim[k]),[1,12*nout**2,odata])],0)
491
-
492
- if axis==0:
493
- if len(ishape)==1:
494
- return self.backend.bk_reshape(imout,[12*nout**2])
1028
+ imout = self.backend.bk_concat(
1029
+ [
1030
+ imout,
1031
+ self.backend.bk_reshape(
1032
+ self.backend.bk_sparse_dense_matmul(
1033
+ self.weight_interp_val[lout][nout], tim[k]
1034
+ ),
1035
+ [1, 12 * nout**2, odata],
1036
+ ),
1037
+ ],
1038
+ 0,
1039
+ )
1040
+
1041
+ if axis == 0:
1042
+ if len(ishape) == 1:
1043
+ return self.backend.bk_reshape(imout, [12 * nout**2])
495
1044
  else:
496
- return self.backend.bk_reshape(imout,[12*nout**2]+ishape[axis+1:])
1045
+ return self.backend.bk_reshape(
1046
+ imout, [12 * nout**2] + ishape[axis + 1 :]
1047
+ )
497
1048
  else:
498
- if len(ishape)==axis+1:
499
- return self.backend.bk_reshape(imout,ishape[0:axis]+[12*nout**2])
1049
+ if len(ishape) == axis + 1:
1050
+ return self.backend.bk_reshape(
1051
+ imout, ishape[0:axis] + [12 * nout**2]
1052
+ )
500
1053
  else:
501
- return self.backend.bk_reshape(imout,ishape[0:axis]+[12*nout**2]+ishape[axis+1:])
502
- return(imout)
503
-
504
- #--------------------------------------------------------
505
- def fill_1d(self,i_arr,nullval=0):
506
- arr=i_arr.copy()
1054
+ return self.backend.bk_reshape(
1055
+ imout, ishape[0:axis] + [12 * nout**2] + ishape[axis + 1 :]
1056
+ )
1057
+ return imout
1058
+
1059
+ # --------------------------------------------------------
1060
+ def fill_1d(self, i_arr, nullval=0):
1061
+ arr = i_arr.copy()
507
1062
  # Indices des éléments non nuls
508
- non_zero_indices = np.where(arr!=nullval)[0]
509
-
1063
+ non_zero_indices = np.where(arr != nullval)[0]
1064
+
510
1065
  # Indices de tous les éléments
511
1066
  all_indices = np.arange(len(arr))
512
-
1067
+
513
1068
  # Interpoler linéairement en utilisant np.interp
514
1069
  # np.interp(x, xp, fp) : x sont les indices pour lesquels on veut obtenir des valeurs
515
1070
  # xp sont les indices des données existantes, fp sont les valeurs des données existantes
516
- interpolated_values = np.interp(all_indices, non_zero_indices, arr[non_zero_indices])
517
-
1071
+ interpolated_values = np.interp(
1072
+ all_indices, non_zero_indices, arr[non_zero_indices]
1073
+ )
1074
+
518
1075
  # Mise à jour du tableau original
519
- arr[arr==nullval] = interpolated_values[arr==nullval]
520
-
1076
+ arr[arr == nullval] = interpolated_values[arr == nullval]
1077
+
521
1078
  return arr
522
1079
 
523
- def fill_2d(self,i_arr,nullval=0):
524
- arr=i_arr.copy()
1080
+ def fill_2d(self, i_arr, nullval=0):
1081
+ arr = i_arr.copy()
525
1082
  # Créer une grille de coordonnées correspondant aux indices du tableau
526
1083
  x, y = np.indices(arr.shape)
527
-
1084
+
528
1085
  # Extraire les coordonnées des points non nuls ainsi que leurs valeurs
529
1086
  non_zero_points = np.array((x[arr != nullval], y[arr != nullval])).T
530
1087
  non_zero_values = arr[arr != nullval]
531
-
1088
+
532
1089
  # Extraire les coordonnées des points nuls
533
1090
  zero_points = np.array((x[arr == nullval], y[arr == nullval])).T
534
1091
 
535
1092
  # Interpolation linéaire
536
- interpolated_values = griddata(non_zero_points, non_zero_values, zero_points, method='linear')
1093
+ interpolated_values = griddata(
1094
+ non_zero_points, non_zero_values, zero_points, method="linear"
1095
+ )
537
1096
 
538
1097
  # Remplacer les valeurs nulles par les valeurs interpolées
539
1098
  arr[arr == nullval] = interpolated_values
540
1099
 
541
1100
  return arr
542
-
543
- def fill_healpy(self,i_map,nmax=10,nullval=hp.UNSEEN):
544
- map=1*i_map
1101
+
1102
+ def fill_healpy(self, i_map, nmax=10, nullval=hp.UNSEEN):
1103
+ map = 1 * i_map
545
1104
  # Trouver les pixels nuls
546
1105
  nside = hp.npix2nside(len(map))
547
1106
  null_indices = np.where(map == nullval)[0]
548
-
549
- itt=0
550
- while null_indices.shape[0]>0 and itt<nmax:
1107
+
1108
+ itt = 0
1109
+ while null_indices.shape[0] > 0 and itt < nmax:
551
1110
  # Trouver les coordonnées theta, phi pour les pixels nuls
552
1111
  theta, phi = hp.pix2ang(nside, null_indices)
553
-
1112
+
554
1113
  # Interpoler les valeurs en utilisant les pixels voisins
555
1114
  # La fonction get_interp_val peut être utilisée pour obtenir les valeurs interpolées
556
1115
  # pour des positions données en theta et phi.
557
1116
  i_idx = hp.get_all_neighbours(nside, theta, phi)
558
-
559
- i_w=(map[i_idx]!=nullval)*(i_idx!=-1)
560
- vv=np.sum(i_w,0)
561
- interpolated_values=np.sum(i_w*map[i_idx],0)
1117
+
1118
+ i_w = (map[i_idx] != nullval) * (i_idx != -1)
1119
+ vv = np.sum(i_w, 0)
1120
+ interpolated_values = np.sum(i_w * map[i_idx], 0)
562
1121
 
563
1122
  # Remplacer les valeurs nulles par les valeurs interpolées
564
- map[null_indices[vv>0]] = interpolated_values[vv>0]/vv[vv>0]
1123
+ map[null_indices[vv > 0]] = interpolated_values[vv > 0] / vv[vv > 0]
565
1124
 
566
1125
  null_indices = np.where(map == nullval)[0]
567
- itt+=1
568
-
1126
+ itt += 1
1127
+
569
1128
  return map
570
-
571
- #--------------------------------------------------------
572
- def ud_grade_1d(self,im,nout,axis=0):
573
- npix=im.shape[axis]
574
-
575
- ishape=list(im.shape)
576
- odata=1
577
- for k in range(axis+1,len(ishape)):
578
- odata=odata*ishape[k]
579
-
580
- ndata=1
581
- for k in range(axis):
582
- ndata=ndata*ishape[k]
583
1129
 
584
- nscale=npix//nout
585
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix//nscale,nscale,odata])
1130
+ # --------------------------------------------------------
1131
+ def ud_grade_1d(self, im, nout, axis=0):
1132
+ npix = im.shape[axis]
1133
+
1134
+ ishape = list(im.shape)
1135
+ odata = 1
1136
+ for k in range(axis + 1, len(ishape)):
1137
+ odata = odata * ishape[k]
586
1138
 
587
- res = self.backend.bk_reduce_mean(tim,2)
588
-
589
- if axis==0:
590
- if len(ishape)==1:
591
- return self.backend.bk_reshape(res,[nout])
1139
+ ndata = 1
1140
+ for k in range(axis):
1141
+ ndata = ndata * ishape[k]
1142
+
1143
+ nscale = npix // nout
1144
+ if npix % nscale == 0:
1145
+ tim = self.backend.bk_reshape(
1146
+ self.backend.bk_cast(im), [ndata, npix // nscale, nscale, odata]
1147
+ )
1148
+ else:
1149
+ im = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1150
+ tim = self.backend.bk_reshape(
1151
+ self.backend.bk_cast(im[:, 0 : nscale * (npix // nscale), :]),
1152
+ [ndata, npix // nscale, nscale, odata],
1153
+ )
1154
+ res = self.backend.bk_reduce_mean(tim, 2)
1155
+
1156
+ if axis == 0:
1157
+ if len(ishape) == 1:
1158
+ return self.backend.bk_reshape(res, [nout])
592
1159
  else:
593
- return self.backend.bk_reshape(res,[nout]+ishape[axis+1:])
1160
+ return self.backend.bk_reshape(res, [nout] + ishape[axis + 1 :])
594
1161
  else:
595
- if len(ishape)==axis+1:
596
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout])
1162
+ if len(ishape) == axis + 1:
1163
+ return self.backend.bk_reshape(res, ishape[0:axis] + [nout])
597
1164
  else:
598
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout]+ishape[axis+1:])
599
- return self.backend.bk_reshape(res,[nout])
600
-
601
- #--------------------------------------------------------
602
- def up_grade_2_1d(self,im,axis=0):
603
-
604
- npix=im.shape[axis]
605
-
606
- ishape=list(im.shape)
607
- odata=1
608
- for k in range(axis+1,len(ishape)):
609
- odata=odata*ishape[k]
610
-
611
- ndata=1
1165
+ return self.backend.bk_reshape(
1166
+ res, ishape[0:axis] + [nout] + ishape[axis + 1 :]
1167
+ )
1168
+ return self.backend.bk_reshape(res, [nout])
1169
+
1170
+ # --------------------------------------------------------
1171
+ def up_grade_2_1d(self, im, axis=0):
1172
+
1173
+ npix = im.shape[axis]
1174
+
1175
+ ishape = list(im.shape)
1176
+ odata = 1
1177
+ for k in range(axis + 1, len(ishape)):
1178
+ odata = odata * ishape[k]
1179
+
1180
+ ndata = 1
612
1181
  for k in range(axis):
613
- ndata=ndata*ishape[k]
614
-
615
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
616
-
617
- res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
618
- res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
619
- res = self.backend.bk_concat([res1,res2],-2)
620
-
621
- if axis==0:
622
- if len(ishape)==1:
623
- return self.backend.bk_reshape(res,[npix*2])
1182
+ ndata = ndata * ishape[k]
1183
+
1184
+ tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1185
+
1186
+ res2 = self.backend.bk_expand_dims(
1187
+ self.backend.bk_concat(
1188
+ [(tim[:, 1:, :] + 3 * tim[:, :-1, :]) / 4, tim[:, -1:, :]], 1
1189
+ ),
1190
+ -2,
1191
+ )
1192
+ res1 = self.backend.bk_expand_dims(
1193
+ self.backend.bk_concat(
1194
+ [tim[:, 0:1, :], (tim[:, 1:, :] * 3 + tim[:, :-1, :]) / 4], 1
1195
+ ),
1196
+ -2,
1197
+ )
1198
+ res = self.backend.bk_concat([res1, res2], -2)
1199
+
1200
+ if axis == 0:
1201
+ if len(ishape) == 1:
1202
+ return self.backend.bk_reshape(res, [npix * 2])
624
1203
  else:
625
- return self.backend.bk_reshape(res,[npix*2]+ishape[axis+1:])
1204
+ return self.backend.bk_reshape(res, [npix * 2] + ishape[axis + 1 :])
626
1205
  else:
627
- if len(ishape)==axis+1:
628
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix*2])
1206
+ if len(ishape) == axis + 1:
1207
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix * 2])
629
1208
  else:
630
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix*2]+ishape[axis+1:])
631
- return self.backend.bk_reshape(res,[npix*2])
632
-
633
-
634
- #--------------------------------------------------------
635
- def convol_1d(self,im,axis=0):
636
-
637
- xx=np.arange(5)-2
638
- w=np.exp(-0.17328679514*(xx)**2)
639
- c=np.cos((xx)*np.pi/2)
640
- s=np.sin((xx)*np.pi/2)
641
-
642
- wr=np.array(w*c).reshape(xx.shape[0],1,1)
643
- wi=np.array(w*s).reshape(xx.shape[0],1,1)
644
-
645
- npix=im.shape[axis]
646
-
647
- ishape=list(im.shape)
648
- odata=1
649
- for k in range(axis+1,len(ishape)):
650
- odata=odata*ishape[k]
651
-
652
- ndata=1
1209
+ return self.backend.bk_reshape(
1210
+ res, ishape[0:axis] + [npix * 2] + ishape[axis + 1 :]
1211
+ )
1212
+ return self.backend.bk_reshape(res, [npix * 2])
1213
+
1214
+ # --------------------------------------------------------
1215
+ def convol_1d(self, im, axis=0):
1216
+
1217
+ xx = np.arange(5) - 2
1218
+ w = np.exp(-0.17328679514 * (xx) ** 2)
1219
+ c = np.cos((xx) * np.pi / 2)
1220
+ s = np.sin((xx) * np.pi / 2)
1221
+
1222
+ wr = np.array(w * c).reshape(xx.shape[0], 1, 1)
1223
+ wi = np.array(w * s).reshape(xx.shape[0], 1, 1)
1224
+
1225
+ npix = im.shape[axis]
1226
+
1227
+ ishape = list(im.shape)
1228
+ odata = 1
1229
+ for k in range(axis + 1, len(ishape)):
1230
+ odata = odata * ishape[k]
1231
+
1232
+ ndata = 1
653
1233
  for k in range(axis):
654
- ndata=ndata*ishape[k]
655
-
656
- if odata>1:
657
- wr=np.repeat(wr,odata,2)
658
- wi=np.repeat(wi,odata,2)
659
-
660
- wr=self.backend.bk_cast(self.backend.constant(wr))
661
- wi=self.backend.bk_cast(self.backend.constant(wi))
662
-
663
- tim = self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
664
-
665
- if tim.dtype==self.all_cbk_type:
666
- rr1 = self.backend.bk_conv1d(self.backend.bk_real(tim),wr)
667
- ii1 = self.backend.bk_conv1d(self.backend.bk_real(tim),wi)
668
- rr2 = self.backend.bk_conv1d(self.backend.bk_imag(tim),wr)
669
- ii2 = self.backend.bk_conv1d(self.backend.bk_imag(tim),wi)
670
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
1234
+ ndata = ndata * ishape[k]
1235
+
1236
+ if odata > 1:
1237
+ wr = np.repeat(wr, odata, 2)
1238
+ wi = np.repeat(wi, odata, 2)
1239
+
1240
+ wr = self.backend.bk_cast(self.backend.constant(wr))
1241
+ wi = self.backend.bk_cast(self.backend.constant(wi))
1242
+
1243
+ tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1244
+
1245
+ if tim.dtype == self.all_cbk_type:
1246
+ rr1 = self.backend.bk_conv1d(self.backend.bk_real(tim), wr)
1247
+ ii1 = self.backend.bk_conv1d(self.backend.bk_real(tim), wi)
1248
+ rr2 = self.backend.bk_conv1d(self.backend.bk_imag(tim), wr)
1249
+ ii2 = self.backend.bk_conv1d(self.backend.bk_imag(tim), wi)
1250
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
671
1251
  else:
672
- rr = self.backend.bk_conv1d(tim,wr)
673
- ii = self.backend.bk_conv1d(tim,wi)
674
-
675
- res=self.backend.bk_complex(rr,ii)
676
-
677
- if axis==0:
678
- if len(ishape)==1:
679
- return self.backend.bk_reshape(res,[npix])
1252
+ rr = self.backend.bk_conv1d(tim, wr)
1253
+ ii = self.backend.bk_conv1d(tim, wi)
1254
+
1255
+ res = self.backend.bk_complex(rr, ii)
1256
+
1257
+ if axis == 0:
1258
+ if len(ishape) == 1:
1259
+ return self.backend.bk_reshape(res, [npix])
680
1260
  else:
681
- return self.backend.bk_reshape(res,[npix]+ishape[axis+1:])
1261
+ return self.backend.bk_reshape(res, [npix] + ishape[axis + 1 :])
682
1262
  else:
683
- if len(ishape)==axis+1:
684
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix])
1263
+ if len(ishape) == axis + 1:
1264
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix])
685
1265
  else:
686
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix]+ishape[axis+1:])
687
- return self.backend.bk_reshape(res,[npix])
688
-
689
-
690
- #--------------------------------------------------------
691
- def smooth_1d(self,im,axis=0):
692
-
693
- xx=np.arange(5)-2
694
- w=np.exp(-0.17328679514*(xx)**2)
695
- w=w/w.sum()
696
- w=np.array(w).reshape(xx.shape[0],1,1)
697
-
698
- npix=im.shape[axis]
699
-
700
- ishape=list(im.shape)
701
- odata=1
702
- for k in range(axis+1,len(ishape)):
703
- odata=odata*ishape[k]
704
-
705
- ndata=1
1266
+ return self.backend.bk_reshape(
1267
+ res, ishape[0:axis] + [npix] + ishape[axis + 1 :]
1268
+ )
1269
+ return self.backend.bk_reshape(res, [npix])
1270
+
1271
+ # --------------------------------------------------------
1272
+ def smooth_1d(self, im, axis=0):
1273
+
1274
+ xx = np.arange(5) - 2
1275
+ w = np.exp(-0.17328679514 * (xx) ** 2)
1276
+ w = w / w.sum()
1277
+ w = np.array(w).reshape(xx.shape[0], 1, 1)
1278
+
1279
+ npix = im.shape[axis]
1280
+
1281
+ ishape = list(im.shape)
1282
+ odata = 1
1283
+ for k in range(axis + 1, len(ishape)):
1284
+ odata = odata * ishape[k]
1285
+
1286
+ ndata = 1
706
1287
  for k in range(axis):
707
- ndata=ndata*ishape[k]
708
-
709
- if odata>1:
710
- w=np.repeat(w,odata,2)
711
-
712
- w=self.backend.bk_cast(self.backend.constant(w))
713
-
714
- tim = self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
715
-
716
- if tim.dtype==self.all_cbk_type:
717
- rr = self.backend.bk_conv1d(self.backend.bk_real(tim),w)
718
- ii = self.backend.bk_conv1d(self.backend.bk_real(tim),w)
719
- res=self.backend.bk_complex(rr,ii)
1288
+ ndata = ndata * ishape[k]
1289
+
1290
+ if odata > 1:
1291
+ w = np.repeat(w, odata, 2)
1292
+
1293
+ w = self.backend.bk_cast(self.backend.constant(w))
1294
+
1295
+ tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1296
+
1297
+ if tim.dtype == self.all_cbk_type:
1298
+ rr = self.backend.bk_conv1d(self.backend.bk_real(tim), w)
1299
+ ii = self.backend.bk_conv1d(self.backend.bk_real(tim), w)
1300
+ res = self.backend.bk_complex(rr, ii)
720
1301
  else:
721
- res=self.backend.bk_conv1d(tim,w)
722
-
723
- if axis==0:
724
- if len(ishape)==1:
725
- return self.backend.bk_reshape(res,[npix])
1302
+ res = self.backend.bk_conv1d(tim, w)
1303
+
1304
+ if axis == 0:
1305
+ if len(ishape) == 1:
1306
+ return self.backend.bk_reshape(res, [npix])
726
1307
  else:
727
- return self.backend.bk_reshape(res,[npix]+ishape[axis+1:])
1308
+ return self.backend.bk_reshape(res, [npix] + ishape[axis + 1 :])
728
1309
  else:
729
- if len(ishape)==axis+1:
730
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix])
1310
+ if len(ishape) == axis + 1:
1311
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix])
731
1312
  else:
732
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix]+ishape[axis+1:])
733
- return self.backend.bk_reshape(res,[npix])
734
-
735
- #--------------------------------------------------------
736
- def up_grade_1d(self,im,nout,axis=0):
737
-
738
- lout=int(im.shape[axis])
739
- nscale=int(np.log(nout//lout)/np.log(2))
740
- res=self.backend.bk_cast(im)
1313
+ return self.backend.bk_reshape(
1314
+ res, ishape[0:axis] + [npix] + ishape[axis + 1 :]
1315
+ )
1316
+ return self.backend.bk_reshape(res, [npix])
1317
+
1318
+ # --------------------------------------------------------
1319
+ def up_grade_1d(self, im, nout, axis=0):
1320
+
1321
+ lout = int(im.shape[axis])
1322
+ nscale = int(np.log(nout // lout) / np.log(2))
1323
+ res = self.backend.bk_cast(im)
741
1324
  for k in range(nscale):
742
- res=self.up_grade_2_1d(res,axis=axis)
743
- return(res)
744
-
1325
+ res = self.up_grade_2_1d(res, axis=axis)
1326
+ return res
1327
+
745
1328
  # ---------------------------------------------−---------
746
- def init_index(self,nside,kernel=-1):
1329
+ def init_index(self, nside, kernel=-1):
747
1330
 
748
- if kernel==-1:
749
- l_kernel=self.KERNELSZ
1331
+ if kernel == -1:
1332
+ l_kernel = self.KERNELSZ
750
1333
  else:
751
- l_kernel=kernel
752
-
753
-
1334
+ l_kernel = kernel
1335
+
754
1336
  try:
755
1337
  if self.use_2D:
756
- tmp=np.load('%s/W%d_%s_%d_IDX.npy'%(self.TEMPLATE_PATH,l_kernel**2,TMPFILE_VERSION,nside))
1338
+ tmp = np.load(
1339
+ "%s/W%d_%s_%d_IDX.npy"
1340
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1341
+ )
757
1342
  else:
758
- tmp=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel**2,self.NORIENT,nside))
1343
+ tmp = np.load(
1344
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1345
+ % (
1346
+ self.TEMPLATE_PATH,
1347
+ TMPFILE_VERSION,
1348
+ l_kernel**2,
1349
+ self.NORIENT,
1350
+ nside,
1351
+ )
1352
+ )
759
1353
  except:
760
- if self.use_2D==False:
1354
+ if not self.use_2D:
1355
+
1356
+ if l_kernel == 5:
1357
+ pw = 0.5
1358
+ pw2 = 0.5
1359
+ threshold = 2e-4
1360
+
1361
+ elif l_kernel == 3:
1362
+ pw = 1.0 / np.sqrt(2)
1363
+ pw2 = 1.0
1364
+ threshold = 1e-3
1365
+
1366
+ elif l_kernel == 7:
1367
+ pw = 0.5
1368
+ pw2 = 0.25
1369
+ threshold = 4e-5
1370
+
1371
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1372
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1373
+
1374
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1375
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1376
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1377
+
1378
+ indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1379
+ indice = np.zeros(
1380
+ [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1381
+ )
1382
+ wav = np.zeros(
1383
+ [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1384
+ )
1385
+ wwav = np.zeros([12 * nside * nside * 64 * self.NORIENT], dtype="float")
1386
+
1387
+ iv = 0
1388
+ iv2 = 0
1389
+ for iii in range(12 * nside * nside):
1390
+
1391
+ if iii % (nside * nside) == nside * nside - 1:
1392
+ if not self.silent:
1393
+ print(
1394
+ "Pre-compute nside=%6d %.2f%%"
1395
+ % (nside, 100 * iii / (12 * nside * nside))
1396
+ )
1397
+
1398
+ hidx = hp.query_disc(
1399
+ nside, [x[iii], y[iii], z[iii]], 2 * np.pi / nside, nest=True
1400
+ )
1401
+
1402
+ R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1403
+
1404
+ t2, p2 = R(th[hidx], ph[hidx])
1405
+
1406
+ vec2 = hp.ang2vec(t2, p2)
1407
+
1408
+ x2 = vec2[:, 0]
1409
+ y2 = vec2[:, 1]
1410
+ z2 = vec2[:, 2]
1411
+
1412
+ ww = np.exp(
1413
+ -pw2
1414
+ * ((nside) ** 2)
1415
+ * ((x2) ** 2 + (y2) ** 2 + (z2 - 1.0) ** 2)
1416
+ )
1417
+ idx = np.where((ww**2) > threshold)[0]
1418
+ nval2 = len(idx)
1419
+ indice2[iv2 : iv2 + nval2, 0] = iii
1420
+ indice2[iv2 : iv2 + nval2, 1] = hidx[idx]
1421
+ wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1422
+ iv2 += nval2
1423
+
1424
+ for l_rotation in range(self.NORIENT):
1425
+
1426
+ angle = (
1427
+ l_rotation / 4.0 * np.pi
1428
+ - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1429
+ - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1430
+ )
1431
+
1432
+ # posi=2*(0.5-(z[hidx]<0))
1433
+
1434
+ axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1435
+ wresr = ww * np.cos(pw * axes * (nside) * np.pi)
1436
+ wresi = ww * np.sin(pw * axes * (nside) * np.pi)
1437
+
1438
+ vnorm = wresr * wresr + wresi * wresi
1439
+ idx = np.where(vnorm > threshold)[0]
1440
+
1441
+ nval = len(idx)
1442
+ indice[iv : iv + nval, 0] = iii * 4 + l_rotation
1443
+ indice[iv : iv + nval, 1] = hidx[idx]
1444
+ # print([hidx[k] for k in idx])
1445
+ # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1446
+ normr = np.mean(wresr[idx])
1447
+ normi = np.mean(wresi[idx])
1448
+
1449
+ val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1450
+ val = val / abs(val).sum()
1451
+
1452
+ wav[iv : iv + nval] = val
1453
+ iv += nval
1454
+
1455
+ indice = indice[:iv, :]
1456
+ wav = wav[:iv]
1457
+ indice2 = indice2[:iv2, :]
1458
+ wwav = wwav[:iv2]
1459
+ if not self.silent:
1460
+ print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1461
+ """
1462
+ # OLD VERSION OLD VERSION OLD VERSION (3.0)
761
1463
  if self.KERNELSZ*self.KERNELSZ>12*nside*nside:
762
- l_kernel=2*nside
763
-
1464
+ l_kernel=3
1465
+
764
1466
  aa=np.cos(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
765
1467
  bb=np.sin(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
766
1468
  x,y,z=hp.pix2vec(nside,np.arange(12*nside*nside),nest=True)
@@ -777,36 +1479,42 @@ class FoCUS:
777
1479
  lidx=np.arange(12*nside*nside)
778
1480
 
779
1481
  pw=np.pi/4.0
780
- pw2=1/2.0
781
-
1482
+ pw2=1/2
1483
+ amp=1.0
1484
+
782
1485
  if l_kernel==5:
783
1486
  pw=np.pi/4.0
784
- pw2=1/2.0
1487
+ pw2=1/2.25
1488
+ amp=1.0/9.2038
1489
+
785
1490
  elif l_kernel==3:
786
- pw=1.0
1491
+ pw=1.0/np.sqrt(2)
787
1492
  pw2=1.0
1493
+ amp=1/8.45
1494
+
788
1495
  elif l_kernel==7:
789
1496
  pw=np.pi/4.0
790
1497
  pw2=1.0/3.0
791
-
1498
+
792
1499
  for k in range(12*nside*nside):
793
1500
  if k%(nside*nside)==0:
794
- print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
1501
+ if not self.silent:
1502
+ print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
795
1503
  if nside>scale*2:
796
1504
  lidx=hp.get_all_neighbours(nside//scale,th[k//(scale*scale)],ph[k//(scale*scale)],nest=True)
797
1505
  lidx=np.concatenate([lidx,np.array([(k//(scale*scale))])],0)
798
1506
  lidx=np.repeat(lidx*(scale*scale),(scale*scale))+ \
799
1507
  np.tile(np.arange((scale*scale)),lidx.shape[0])
800
-
1508
+
801
1509
  delta=(x[lidx]-x[k])**2+(y[lidx]-y[k])**2+(z[lidx]-z[k])**2
802
1510
  pidx=np.where(delta<(10)/(nside**2))[0]
803
1511
  if len(pidx)<l_kernel**2:
804
1512
  pidx=np.arange(delta.shape[0])
805
-
1513
+
806
1514
  w=np.exp(-pw2*delta[pidx]*(nside**2))
807
1515
  pidx=pidx[np.argsort(-w)[0:l_kernel**2]]
808
1516
  pidx=pidx[np.argsort(lidx[pidx])]
809
-
1517
+
810
1518
  w=np.exp(-pw2*delta[pidx]*(nside**2))
811
1519
  iwav[k]=lidx[pidx]
812
1520
  wwav[k]=w
@@ -814,16 +1522,16 @@ class FoCUS:
814
1522
  r=hp.Rotator(rot=rot)
815
1523
  ty,tx=r(to[iwav[k]],po[iwav[k]])
816
1524
  ty=ty-np.pi/2
817
-
1525
+
818
1526
  xx=np.expand_dims(pw*nside*np.pi*tx/np.cos(ty),-1)
819
1527
  yy=np.expand_dims(pw*nside*np.pi*ty,-1)
820
-
1528
+
821
1529
  wav[k,:,:]=(np.cos(xx*aa+yy*bb)+complex(0.0,1.0)*np.sin(xx*aa+yy*bb))*np.expand_dims(w,-1)
822
-
1530
+
823
1531
  wav=wav-np.expand_dims(np.mean(wav,1),1)
824
- wav=wav/np.expand_dims(np.std(wav,1),1)
1532
+ wav=amp*wav/np.expand_dims(np.std(wav,1),1)
825
1533
  wwav=wwav/np.expand_dims(np.sum(wwav,1),1)
826
-
1534
+
827
1535
  nk=l_kernel*l_kernel
828
1536
  indice=np.zeros([12*nside*nside*nk*self.NORIENT,2],dtype='int')
829
1537
  lidx=np.arange(self.NORIENT)
@@ -835,561 +1543,1164 @@ class FoCUS:
835
1543
  for i in range(12*nside*nside):
836
1544
  indice2[i*nk:i*nk+nk,0]=i
837
1545
  indice2[i*nk:i*nk+nk,1]=iwav[i]
838
-
1546
+
839
1547
  w=np.zeros([12*nside*nside,wav.shape[2],wav.shape[1]],dtype='complex')
840
1548
  for i in range(wav.shape[1]):
841
1549
  for j in range(wav.shape[2]):
842
1550
  w[:,j,i]=wav[:,i,j]
843
1551
  wav=w.flatten()
844
1552
  wwav=wwav.flatten()
845
-
846
- print('Write FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
847
- np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice)
848
- np.save('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),wav)
849
- np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice2)
850
- np.save('%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),wwav)
1553
+ """
1554
+ if not self.silent:
1555
+ print(
1556
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1557
+ % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1558
+ )
1559
+ np.save(
1560
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1561
+ % (
1562
+ self.TEMPLATE_PATH,
1563
+ TMPFILE_VERSION,
1564
+ self.KERNELSZ**2,
1565
+ self.NORIENT,
1566
+ nside,
1567
+ ),
1568
+ indice,
1569
+ )
1570
+ np.save(
1571
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1572
+ % (
1573
+ self.TEMPLATE_PATH,
1574
+ TMPFILE_VERSION,
1575
+ self.KERNELSZ**2,
1576
+ self.NORIENT,
1577
+ nside,
1578
+ ),
1579
+ wav,
1580
+ )
1581
+ np.save(
1582
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1583
+ % (
1584
+ self.TEMPLATE_PATH,
1585
+ TMPFILE_VERSION,
1586
+ self.KERNELSZ**2,
1587
+ self.NORIENT,
1588
+ nside,
1589
+ ),
1590
+ indice2,
1591
+ )
1592
+ np.save(
1593
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1594
+ % (
1595
+ self.TEMPLATE_PATH,
1596
+ TMPFILE_VERSION,
1597
+ self.KERNELSZ**2,
1598
+ self.NORIENT,
1599
+ nside,
1600
+ ),
1601
+ wwav,
1602
+ )
851
1603
  else:
852
- if l_kernel**2==9:
853
- if self.rank==0:
1604
+ if l_kernel**2 == 9:
1605
+ if self.rank == 0:
854
1606
  self.comp_idx_w9(nside)
855
- elif l_kernel**2==25:
856
- if self.rank==0:
1607
+ elif l_kernel**2 == 25:
1608
+ if self.rank == 0:
857
1609
  self.comp_idx_w25(nside)
858
1610
  else:
859
- if self.rank==0:
860
- print('Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d'%(KERNELSZ,KERNELSZ))
861
- exit(0)
862
-
863
- self.barrier()
864
- if self.use_2D:
865
- tmp=np.load('%s/W%d_%s_%d_IDX.npy'%(self.TEMPLATE_PATH,l_kernel**2,TMPFILE_VERSION,nside))
1611
+ if self.rank == 0:
1612
+ if not self.silent:
1613
+ print(
1614
+ "Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d"
1615
+ % (self.KERNELSZ, self.KERNELSZ)
1616
+ )
1617
+ return None
1618
+
1619
+ self.barrier()
1620
+ if self.use_2D:
1621
+ tmp = np.load(
1622
+ "%s/W%d_%s_%d_IDX.npy"
1623
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1624
+ )
866
1625
  else:
867
- tmp=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
868
- tmp2=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
869
- wr=np.load('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside)).real
870
- wi=np.load('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside)).imag
871
- ws=self.slope*np.load('%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
872
-
873
- wr=self.backend.bk_SparseTensor(self.backend.constant(tmp),self.backend.constant(self.backend.bk_cast(wr)),dense_shape=[12*nside**2*self.NORIENT,12*nside**2])
874
- wi=self.backend.bk_SparseTensor(self.backend.constant(tmp),self.backend.constant(self.backend.bk_cast(wi)),dense_shape=[12*nside**2*self.NORIENT,12*nside**2])
875
- ws=self.backend.bk_SparseTensor(self.backend.constant(tmp2),self.backend.constant(self.backend.bk_cast(ws)),dense_shape=[12*nside**2,12*nside**2])
876
-
877
- if kernel==-1:
878
- self.Idx_Neighbours[nside]=tmp
879
-
1626
+ tmp = np.load(
1627
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1628
+ % (
1629
+ self.TEMPLATE_PATH,
1630
+ TMPFILE_VERSION,
1631
+ self.KERNELSZ**2,
1632
+ self.NORIENT,
1633
+ nside,
1634
+ )
1635
+ )
1636
+ tmp2 = np.load(
1637
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1638
+ % (
1639
+ self.TEMPLATE_PATH,
1640
+ TMPFILE_VERSION,
1641
+ self.KERNELSZ**2,
1642
+ self.NORIENT,
1643
+ nside,
1644
+ )
1645
+ )
1646
+ wr = np.load(
1647
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1648
+ % (
1649
+ self.TEMPLATE_PATH,
1650
+ TMPFILE_VERSION,
1651
+ self.KERNELSZ**2,
1652
+ self.NORIENT,
1653
+ nside,
1654
+ )
1655
+ ).real
1656
+ wi = np.load(
1657
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1658
+ % (
1659
+ self.TEMPLATE_PATH,
1660
+ TMPFILE_VERSION,
1661
+ self.KERNELSZ**2,
1662
+ self.NORIENT,
1663
+ nside,
1664
+ )
1665
+ ).imag
1666
+ ws = self.slope * np.load(
1667
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1668
+ % (
1669
+ self.TEMPLATE_PATH,
1670
+ TMPFILE_VERSION,
1671
+ self.KERNELSZ**2,
1672
+ self.NORIENT,
1673
+ nside,
1674
+ )
1675
+ )
1676
+
1677
+ wr = self.backend.bk_SparseTensor(
1678
+ self.backend.constant(tmp),
1679
+ self.backend.constant(self.backend.bk_cast(wr)),
1680
+ dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1681
+ )
1682
+ wi = self.backend.bk_SparseTensor(
1683
+ self.backend.constant(tmp),
1684
+ self.backend.constant(self.backend.bk_cast(wi)),
1685
+ dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1686
+ )
1687
+ ws = self.backend.bk_SparseTensor(
1688
+ self.backend.constant(tmp2),
1689
+ self.backend.constant(self.backend.bk_cast(ws)),
1690
+ dense_shape=[12 * nside**2, 12 * nside**2],
1691
+ )
1692
+
1693
+ if kernel == -1:
1694
+ self.Idx_Neighbours[nside] = tmp
1695
+
880
1696
  if self.use_2D:
881
- if kernel!=-1:
1697
+ if kernel != -1:
882
1698
  return tmp
883
-
884
- return wr,wi,ws,tmp
885
-
886
- # ---------------------------------------------−---------
887
- # Compute x [....,a,....] to [....,a*a,....]
888
- #NOT YET TESTED OR IMPLEMENTED
889
- def auto_cross_2(x,axis=0):
890
- shape=np.array(x.shape)
891
- if axis==0:
892
- y1=self.reshape(x,[shape[0],1,np.cumprod(shape[1:])])
893
- y2=self.reshape(x,[1,shape[0],np.cumprod(shape[1:])])
894
- oshape=np.concat([shape[0],shape[0],shape[1:]])
895
- return(self.reshape(y1*y2,oshape))
896
-
897
- # ---------------------------------------------−---------
898
- # Compute x [....,a,....,b,....] to [....,b*b,....,a*a,....]
899
- #NOT YET TESTED OR IMPLEMENTED
900
- def auto_cross_2(x,axis1=0,axis2=1):
901
- shape=np.array(x.shape)
902
- if axis==0:
903
- y1=self.reshape(x,[shape[0],1,np.cumprod(shape[1:])])
904
- y2=self.reshape(x,[1,shape[0],np.cumprod(shape[1:])])
905
- oshape=np.concat([shape[0],shape[0],shape[1:]])
906
- return(self.reshape(y1*y2,oshape))
907
-
908
-
1699
+
1700
+ return wr, wi, ws, tmp
1701
+
909
1702
  # ---------------------------------------------−---------
910
1703
  # convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
911
- def swapaxes(self,x,axis1,axis2):
912
- shape=list(x.shape)
913
- if axis1<0:
914
- laxis1=len(shape)+axis1
1704
+ def swapaxes(self, x, axis1, axis2):
1705
+ shape = list(x.shape)
1706
+ if axis1 < 0:
1707
+ laxis1 = len(shape) + axis1
915
1708
  else:
916
- laxis1=axis1
917
- if axis2<0:
918
- laxis2=len(shape)+axis2
1709
+ laxis1 = axis1
1710
+ if axis2 < 0:
1711
+ laxis2 = len(shape) + axis2
919
1712
  else:
920
- laxis2=axis2
921
-
922
- naxes=len(shape)
923
- thelist=[i for i in range(naxes)]
924
- thelist[laxis1]=laxis2
925
- thelist[laxis2]=laxis1
926
- return self.backend.bk_transpose(x,thelist)
927
-
1713
+ laxis2 = axis2
1714
+
1715
+ naxes = len(shape)
1716
+ thelist = [i for i in range(naxes)]
1717
+ thelist[laxis1] = laxis2
1718
+ thelist[laxis2] = laxis1
1719
+ return self.backend.bk_transpose(x, thelist)
1720
+
928
1721
  # ---------------------------------------------−---------
929
1722
  # Mean using mask x [....,Npix,....], mask[Nmask,Npix] to [....,Nmask,....]
930
1723
  # if use_2D
931
1724
  # Mean using mask x [....,12,Nside+2*off,Nside+2*off,....], mask[Nmask,12,Nside+2*off,Nside+2*off] to [....,Nmask,....]
932
- def masked_mean(self,x,mask,axis=0,rank=0,calc_var=False):
933
-
934
- #==========================================================================
1725
+ def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
1726
+
1727
+ # ==========================================================================
935
1728
  # in input data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]]
936
1729
  # in input mask=[Nmask,X[,Y]]
937
1730
  # if self.use_2D : X[,Y]] = [X,Y]
938
1731
  # if second level: NORIENT[,NORIENT]= NORIENT,NORIENT
939
- #==========================================================================
940
-
941
- shape=list(x.shape)
942
-
1732
+ # ==========================================================================
1733
+
1734
+ shape = list(x.shape)
1735
+
943
1736
  if not self.use_2D:
944
- nside=int(np.sqrt(x.shape[axis]//12))
945
-
946
- l_mask=mask
1737
+ nside = int(np.sqrt(x.shape[axis] // 12))
1738
+
1739
+ l_mask = mask
947
1740
  if self.mask_norm:
948
- sum_mask=self.backend.bk_reduce_sum(self.backend.bk_reshape(l_mask,[l_mask.shape[0],np.prod(np.array(l_mask.shape[1:]))]),1)
1741
+ sum_mask = self.backend.bk_reduce_sum(
1742
+ self.backend.bk_reshape(
1743
+ l_mask, [l_mask.shape[0], np.prod(np.array(l_mask.shape[1:]))]
1744
+ ),
1745
+ 1,
1746
+ )
949
1747
  if not self.use_2D:
950
- l_mask=12*nside*nside*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
951
- else:
952
- l_mask=mask.shape[1]*mask.shape[2]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
953
-
954
- if self.use_2D:
955
- if self.padding=='VALID' and shape[axis]!=l_mask.shape[1]:
956
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
957
- if shape[axis]!=l_mask.shape[1]:
958
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1748
+ l_mask = (
1749
+ 12
1750
+ * nside
1751
+ * nside
1752
+ * l_mask
1753
+ / self.backend.bk_reshape(
1754
+ sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]]
1755
+ )
1756
+ )
1757
+ elif self.use_2D:
1758
+ l_mask = (
1759
+ mask.shape[1]
1760
+ * mask.shape[2]
1761
+ * l_mask
1762
+ / self.backend.bk_reshape(
1763
+ sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]]
1764
+ )
1765
+ )
959
1766
  else:
960
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
961
-
962
- # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,...,KERNELSZ//2:-self.KERNELSZ//2,KERNELSZ//2:-self.KERNELSZ//2,NORIENT[,NORIENT]]
1767
+ l_mask = (
1768
+ mask.shape[1]
1769
+ * l_mask
1770
+ / self.backend.bk_reshape(
1771
+ sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]]
1772
+ )
1773
+ )
1774
+
963
1775
  if self.use_2D:
964
- ichannel=1
1776
+ if self.padding == "VALID":
1777
+ l_mask = l_mask[
1778
+ :,
1779
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1780
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1781
+ ]
1782
+ if shape[axis] != l_mask.shape[1]:
1783
+ l_mask = l_mask[
1784
+ :,
1785
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1786
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1787
+ ]
1788
+
1789
+ ichannel = 1
965
1790
  for i in range(axis):
966
- ichannel*=shape[i]
967
- ochannel=1
968
- for i in range(axis+2,len(shape)):
969
- ochannel*=shape[i]
970
- l_x=self.backend.bk_reshape(x,[ichannel,shape[axis],shape[axis+1],ochannel])
971
- oshape=[k for k in shape]
972
- oshape[axis]=oshape[axis]-self.KERNELSZ+1
973
- oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
974
- l_x=self.backend.bk_reshape(l_x[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1791
+ ichannel *= shape[i]
1792
+ ochannel = 1
1793
+ for i in range(axis + 2, len(shape)):
1794
+ ochannel *= shape[i]
1795
+ l_x = self.backend.bk_reshape(
1796
+ x, [ichannel, 1, shape[axis], shape[axis + 1], ochannel]
1797
+ )
1798
+
1799
+ if self.padding == "VALID":
1800
+ oshape = [k for k in shape]
1801
+ oshape[axis] = oshape[axis] - self.KERNELSZ + 1
1802
+ oshape[axis + 1] = oshape[axis + 1] - self.KERNELSZ + 1
1803
+ l_x = self.backend.bk_reshape(
1804
+ l_x[
1805
+ :,
1806
+ :,
1807
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1808
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1809
+ :,
1810
+ ],
1811
+ oshape,
1812
+ )
1813
+
1814
+ elif self.use_1D:
1815
+ if self.padding == "VALID":
1816
+ l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1]
1817
+ if shape[axis] != l_mask.shape[1]:
1818
+ l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1]
1819
+
1820
+ ichannel = 1
1821
+ for i in range(axis):
1822
+ ichannel *= shape[i]
1823
+ ochannel = 1
1824
+ for i in range(axis + 1, len(shape)):
1825
+ ochannel *= shape[i]
1826
+ l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1827
+
1828
+ if self.padding == "VALID":
1829
+ oshape = [k for k in shape]
1830
+ oshape[axis] = oshape[axis] - self.KERNELSZ + 1
1831
+ l_x = self.backend.bk_reshape(
1832
+ l_x[:, :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, :], oshape
1833
+ )
975
1834
  else:
976
- l_x=x
977
-
1835
+ ichannel = 1
1836
+ for i in range(axis):
1837
+ ichannel *= shape[i]
1838
+ ochannel = 1
1839
+ for i in range(axis + 1, len(shape)):
1840
+ ochannel *= shape[i]
1841
+ l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1842
+
978
1843
  # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,1,...,X[,Y],NORIENT[,NORIENT]]
979
- l_x=self.backend.bk_expand_dims(l_x,1)
980
-
981
1844
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,X[,Y]]
982
- l_mask=self.backend.bk_expand_dims(l_mask,0)
983
-
984
- # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
985
- for i in range(1,axis):
986
- l_mask=self.backend.bk_expand_dims(l_mask,axis)
987
-
988
- if l_x.dtype==self.all_cbk_type:
989
- l_mask=self.backend.bk_complex(l_mask,self.backend.bk_cast(0.0*l_mask))
990
-
1845
+ l_mask = self.backend.bk_expand_dims(l_mask, 0)
1846
+ # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,X[,Y],1]
1847
+ l_mask = self.backend.bk_expand_dims(l_mask, -1)
1848
+
1849
+ if l_x.dtype == self.all_cbk_type:
1850
+ l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
1851
+
991
1852
  if self.use_2D:
1853
+ mtmp = l_mask
1854
+ vtmp = l_x
1855
+
1856
+ v1 = self.backend.bk_reduce_sum(
1857
+ self.backend.bk_reduce_sum(mtmp * vtmp, axis=2), 2
1858
+ )
1859
+ v2 = self.backend.bk_reduce_sum(
1860
+ self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2), 2
1861
+ )
1862
+ vh = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp, axis=2), 2)
1863
+
1864
+ res = v1 / vh
1865
+
1866
+ oshape = []
1867
+ if axis > 0:
1868
+ oshape = oshape + list(x.shape[0:axis])
1869
+ oshape = oshape + [mask.shape[0]]
1870
+ if axis + 1 < len(x.shape):
1871
+ oshape = oshape + list(x.shape[axis + 2 :])
992
1872
 
993
- # mask=[1,Nmask,....,X,Y] => mask=[1,Nmask,....,X,Y,....]
994
- for i in range(axis+2,len(x.shape)):
995
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1873
+ if calc_var:
1874
+ if self.backend.bk_is_complex(vtmp):
1875
+ res2 = self.backend.bk_sqrt(
1876
+ (
1877
+ (
1878
+ self.backend.bk_real(v2) / self.backend.bk_real(vh)
1879
+ - self.backend.bk_real(res) * self.backend.bk_real(res)
1880
+ )
1881
+ + (
1882
+ self.backend.bk_imag(v2) / self.backend.bk_real(vh)
1883
+ - self.backend.bk_imag(res) * self.backend.bk_imag(res)
1884
+ )
1885
+ )
1886
+ / self.backend.bk_real(vh)
1887
+ )
1888
+ else:
1889
+ res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
996
1890
 
997
- shape1=list(l_mask.shape)
998
- shape2=list(l_x.shape)
1891
+ res = self.backend.bk_reshape(res, oshape)
1892
+ res2 = self.backend.bk_reshape(res2, oshape)
1893
+ return res, res2
1894
+ else:
1895
+ res = self.backend.bk_reshape(res, oshape)
1896
+ return res
1897
+
1898
+ elif self.use_1D:
1899
+ mtmp = l_mask
1900
+ vtmp = l_x
1901
+
1902
+ v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=2)
1903
+ v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2)
1904
+ vh = self.backend.bk_reduce_sum(mtmp, axis=2)
999
1905
 
1000
- oshape1=shape1[0:axis+1]+[shape1[axis+1]*shape1[axis+2]]+shape1[axis+3:]
1001
- oshape2=shape2[0:axis+1]+[shape2[axis+1]*shape2[axis+2]]+shape2[axis+3:]
1002
-
1003
- mtmp=self.backend.bk_reshape(l_mask,oshape1)
1004
- vtmp=self.backend.bk_reshape(l_x,oshape2)
1906
+ res = v1 / vh
1005
1907
 
1006
- v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=axis+1)
1007
- v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=axis+1)
1008
- vh=self.backend.bk_reduce_sum(mtmp,axis=axis+1)
1908
+ oshape = []
1909
+ if axis > 0:
1910
+ oshape = oshape + list(x.shape[0:axis])
1911
+ oshape = oshape + [mask.shape[0]]
1912
+ if axis + 1 < len(x.shape):
1913
+ oshape = oshape + list(x.shape[axis + 1 :])
1009
1914
 
1010
- res=v1/vh
1011
1915
  if calc_var:
1012
1916
  if self.backend.bk_is_complex(vtmp):
1013
- res2=self.backend.bk_complex(self.backend.bk_sqrt(self.backend.bk_real(v2)/self.backend.bk_real(vh)
1014
- -self.backend.bk_real(res)*self.backend.bk_real(res)), \
1015
- self.backend.bk_sqrt(self.backend.bk_imag(v2)/self.backend.bk_real(vh)
1016
- -self.backend.bk_imag(res)*self.backend.bk_imag(res)))
1917
+ res2 = self.backend.bk_sqrt(
1918
+ (
1919
+ (
1920
+ self.backend.bk_real(v2) / self.backend.bk_real(vh)
1921
+ - self.backend.bk_real(res) * self.backend.bk_real(res)
1922
+ )
1923
+ + (
1924
+ self.backend.bk_imag(v2) / self.backend.bk_real(vh)
1925
+ - self.backend.bk_imag(res) * self.backend.bk_imag(res)
1926
+ )
1927
+ )
1928
+ / self.backend.bk_real(vh)
1929
+ )
1017
1930
  else:
1018
- res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1019
- return res,res2
1931
+ res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1932
+
1933
+ res = self.backend.bk_reshape(res, oshape)
1934
+ res2 = self.backend.bk_reshape(res2, oshape)
1935
+ return res, res2
1020
1936
  else:
1937
+ res = self.backend.bk_reshape(res, oshape)
1021
1938
  return res
1939
+
1022
1940
  else:
1023
- # mask=[1,Nmask,....,X] => mask=[1,Nmask,....,X,....]
1024
- for i in range(axis+1,len(x.shape)):
1025
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1026
-
1027
- v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=axis+1)
1028
- v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=axis+1)
1029
- vh=self.backend.bk_reduce_sum(l_mask,axis=axis+1)
1030
-
1031
- res=v1/vh
1941
+ v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=2)
1942
+ v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=2)
1943
+ vh = self.backend.bk_reduce_sum(l_mask, axis=2)
1944
+
1945
+ res = v1 / vh
1946
+
1947
+ oshape = []
1948
+ if axis > 0:
1949
+ oshape = oshape + list(x.shape[0:axis])
1950
+ oshape = oshape + [mask.shape[0]]
1951
+ if axis + 1 < len(x.shape):
1952
+ oshape = oshape + list(x.shape[axis + 1 :])
1953
+
1032
1954
  if calc_var:
1033
1955
  if self.backend.bk_is_complex(l_x):
1034
- res2=self.backend.bk_complex(self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1035
- -self.backend.bk_real(res)*self.backend.bk_real(res))/self.backend.bk_real(v2)), \
1036
- self.backend.bk_sqrt((self.backend.bk_imag(v2)/self.backend.bk_real(vh)
1037
- -self.backend.bk_imag(res)*self.backend.bk_imag(res))/self.backend.bk_real(v2)))
1956
+ res2 = self.backend.bk_sqrt(
1957
+ (
1958
+ self.backend.bk_real(v2) / self.backend.bk_real(vh)
1959
+ - self.backend.bk_real(res) * self.backend.bk_real(res)
1960
+ + self.backend.bk_imag(v2) / self.backend.bk_real(vh)
1961
+ - self.backend.bk_imag(res) * self.backend.bk_imag(res)
1962
+ )
1963
+ / self.backend.bk_real(vh)
1964
+ )
1038
1965
  else:
1039
- res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1040
- return res,res2
1966
+ res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1967
+
1968
+ res = self.backend.bk_reshape(res, oshape)
1969
+ res2 = self.backend.bk_reshape(res2, oshape)
1970
+ return res, res2
1041
1971
  else:
1972
+ res = self.backend.bk_reshape(res, oshape)
1042
1973
  return res
1043
-
1974
+
1044
1975
  # ---------------------------------------------−---------
1045
1976
  # convert tensor x [....,a,b,....] to [....,a*b,....]
1046
- def reduce_dim(self,x,axis=0):
1047
- shape=list(x.shape)
1048
-
1049
- if axis<0:
1050
- laxis=len(shape)+axis
1977
+ def reduce_dim(self, x, axis=0):
1978
+ shape = list(x.shape)
1979
+
1980
+ if axis < 0:
1981
+ laxis = len(shape) + axis
1051
1982
  else:
1052
- laxis=axis
1053
-
1054
- if laxis>0 :
1055
- oshape=shape[0:laxis]
1056
- oshape.append(shape[laxis]*shape[laxis+1])
1983
+ laxis = axis
1984
+
1985
+ if laxis > 0:
1986
+ oshape = shape[0:laxis]
1987
+ oshape.append(shape[laxis] * shape[laxis + 1])
1057
1988
  else:
1058
- oshape=[shape[laxis]*shape[laxis+1]]
1059
-
1060
- if laxis<len(shape)-1:
1061
- oshape.extend(shape[laxis+2:])
1062
-
1063
- return(self.backend.bk_reshape(x,oshape))
1064
-
1065
-
1989
+ oshape = [shape[laxis] * shape[laxis + 1]]
1990
+
1991
+ if laxis < len(shape) - 1:
1992
+ oshape.extend(shape[laxis + 2 :])
1993
+
1994
+ return self.backend.bk_reshape(x, oshape)
1995
+
1066
1996
  # ---------------------------------------------−---------
1067
- def conv2d(self,image,ww,axis=0):
1997
+ def conv2d(self, image, ww, axis=0):
1068
1998
 
1069
- if len(ww.shape)==2:
1070
- norient=ww.shape[1]
1999
+ if len(ww.shape) == 2:
2000
+ norient = ww.shape[1]
1071
2001
  else:
1072
- norient=ww.shape[2]
2002
+ norient = ww.shape[2]
1073
2003
 
1074
- shape=image.shape
2004
+ shape = image.shape
1075
2005
 
1076
- if axis>0:
1077
- o_shape=shape[0]
1078
- for k in range(1,axis+1):
1079
- o_shape=o_shape*shape[k]
2006
+ if axis > 0:
2007
+ o_shape = shape[0]
2008
+ for k in range(1, axis + 1):
2009
+ o_shape = o_shape * shape[k]
1080
2010
  else:
1081
- o_shape=image.shape[0]
1082
-
1083
- if len(shape)>axis+3:
1084
- ishape=shape[axis+3]
1085
- for k in range(axis+4,len(shape)):
1086
- ishape=ishape*shape[k]
1087
-
1088
- oshape=[o_shape,shape[axis+1],shape[axis+2],ishape]
1089
-
1090
- #l_image=self.swapaxes(self.bk_reshape(image,oshape),-1,-3)
1091
- l_image=self.backend.bk_reshape(image,oshape)
1092
-
1093
- l_ww=np.zeros([self.KERNELSZ,self.KERNELSZ,ishape,ishape*norient])
2011
+ o_shape = image.shape[0]
2012
+
2013
+ if len(shape) > axis + 3:
2014
+ ishape = shape[axis + 3]
2015
+ for k in range(axis + 4, len(shape)):
2016
+ ishape = ishape * shape[k]
2017
+
2018
+ oshape = [o_shape, shape[axis + 1], shape[axis + 2], ishape]
2019
+
2020
+ # l_image=self.swapaxes(self.bk_reshape(image,oshape),-1,-3)
2021
+ l_image = self.backend.bk_reshape(image, oshape)
2022
+
2023
+ l_ww = np.zeros([self.KERNELSZ, self.KERNELSZ, ishape, ishape * norient])
1094
2024
  for k in range(ishape):
1095
- l_ww[:,:,k,k*norient:(k+1)*norient]=ww.reshape(self.KERNELSZ,self.KERNELSZ,norient)
1096
-
2025
+ l_ww[:, :, k, k * norient : (k + 1) * norient] = ww.reshape(
2026
+ self.KERNELSZ, self.KERNELSZ, norient
2027
+ )
2028
+
1097
2029
  if self.backend.bk_is_complex(l_image):
1098
- r=self.backend.conv2d(self.backend.bk_real(l_image),
1099
- l_ww,
1100
- strides=[1, 1, 1, 1],
1101
- padding=self.padding)
1102
- i=self.backend.conv2d(self.backend.bk_imag(l_image),
1103
- l_ww,
1104
- strides=[1, 1, 1, 1],
1105
- padding=self.padding)
1106
- res=self.backend.bk_complex(r,i)
2030
+ r = self.backend.conv2d(
2031
+ self.backend.bk_real(l_image),
2032
+ l_ww,
2033
+ strides=[1, 1, 1, 1],
2034
+ padding=self.padding,
2035
+ )
2036
+ i = self.backend.conv2d(
2037
+ self.backend.bk_imag(l_image),
2038
+ l_ww,
2039
+ strides=[1, 1, 1, 1],
2040
+ padding=self.padding,
2041
+ )
2042
+ res = self.backend.bk_complex(r, i)
1107
2043
  else:
1108
- res=self.backend.conv2d(l_image,l_ww,strides=[1, 1, 1, 1],padding=self.padding)
2044
+ res = self.backend.conv2d(
2045
+ l_image, l_ww, strides=[1, 1, 1, 1], padding=self.padding
2046
+ )
1109
2047
 
1110
- res=self.backend.bk_reshape(res,[o_shape,shape[axis+1],shape[axis+2],ishape,norient])
2048
+ res = self.backend.bk_reshape(
2049
+ res, [o_shape, shape[axis + 1], shape[axis + 2], ishape, norient]
2050
+ )
1111
2051
  else:
1112
- oshape=[o_shape,shape[axis+1],shape[axis+2],1]
1113
- l_ww=self.backend.bk_reshape(ww,[self.KERNELSZ,self.KERNELSZ,1,norient])
2052
+ oshape = [o_shape, shape[axis + 1], shape[axis + 2], 1]
2053
+ l_ww = self.backend.bk_reshape(
2054
+ ww, [self.KERNELSZ, self.KERNELSZ, 1, norient]
2055
+ )
1114
2056
 
1115
- tmp=self.backend.bk_reshape(image,oshape)
2057
+ tmp = self.backend.bk_reshape(image, oshape)
1116
2058
  if self.backend.bk_is_complex(tmp):
1117
- r=self.backend.conv2d(self.backend.bk_real(tmp),
1118
- l_ww,
1119
- strides=[1, 1, 1, 1],
1120
- padding=self.padding)
1121
- i=self.backend.conv2d(self.backend.bk_imag(tmp),
1122
- l_ww,
1123
- strides=[1, 1, 1, 1],
1124
- padding=self.padding)
1125
- res=self.backend.bk_complex(r,i)
2059
+ r = self.backend.conv2d(
2060
+ self.backend.bk_real(tmp),
2061
+ l_ww,
2062
+ strides=[1, 1, 1, 1],
2063
+ padding=self.padding,
2064
+ )
2065
+ i = self.backend.conv2d(
2066
+ self.backend.bk_imag(tmp),
2067
+ l_ww,
2068
+ strides=[1, 1, 1, 1],
2069
+ padding=self.padding,
2070
+ )
2071
+ res = self.backend.bk_complex(r, i)
2072
+ else:
2073
+ res = self.backend.conv2d(
2074
+ tmp, l_ww, strides=[1, 1, 1, 1], padding=self.padding
2075
+ )
2076
+
2077
+ return self.backend.bk_reshape(res, shape + [norient])
2078
+
2079
+ def diff_data(self, x, y, is_complex=True, sigma=None):
2080
+ if sigma is None:
2081
+ if self.backend.bk_is_complex(x):
2082
+ r = self.backend.bk_square(
2083
+ self.backend.bk_real(x) - self.backend.bk_real(y)
2084
+ )
2085
+ i = self.backend.bk_square(
2086
+ self.backend.bk_imag(x) - self.backend.bk_imag(y)
2087
+ )
2088
+ return self.backend.bk_reduce_sum(r + i)
2089
+ else:
2090
+ r = self.backend.bk_square(x - y)
2091
+ return self.backend.bk_reduce_sum(r)
2092
+ else:
2093
+ if self.backend.bk_is_complex(x):
2094
+ r = self.backend.bk_square(
2095
+ (self.backend.bk_real(x) - self.backend.bk_real(y)) / sigma
2096
+ )
2097
+ i = self.backend.bk_square(
2098
+ (self.backend.bk_imag(x) - self.backend.bk_imag(y)) / sigma
2099
+ )
2100
+ return self.backend.bk_reduce_sum(r + i)
1126
2101
  else:
1127
- res=self.backend.conv2d(tmp,
1128
- l_ww,
1129
- strides=[1, 1, 1, 1],
1130
- padding=self.padding)
2102
+ r = self.backend.bk_square((x - y) / sigma)
2103
+ return self.backend.bk_reduce_sum(r)
1131
2104
 
1132
- return self.backend.bk_reshape(res,shape+[norient])
1133
-
1134
2105
  # ---------------------------------------------−---------
1135
- def convol(self,in_image,axis=0):
2106
+ def convol(self, in_image, axis=0):
2107
+
2108
+ image = self.backend.bk_cast(in_image)
1136
2109
 
1137
- image=self.backend.bk_cast(in_image)
1138
-
1139
2110
  if self.use_2D:
1140
-
1141
- ishape=list(in_image.shape)
1142
- if len(ishape)<axis+2:
1143
- print('Use of 2D scat with data that has less than 2D')
1144
- exit(0)
1145
-
1146
- npix=ishape[axis]
1147
- npiy=ishape[axis+1]
1148
- odata=1
1149
- if len(ishape)>axis+2:
1150
- for k in range(axis+2,len(ishape)):
1151
- odata=odata*ishape[k]
1152
-
1153
- ndata=1
2111
+ ishape = list(in_image.shape)
2112
+ if len(ishape) < axis + 2:
2113
+ if not self.silent:
2114
+ print("Use of 2D scat with data that has less than 2D")
2115
+ return None
2116
+
2117
+ npix = ishape[axis]
2118
+ npiy = ishape[axis + 1]
2119
+ odata = 1
2120
+ if len(ishape) > axis + 2:
2121
+ for k in range(axis + 2, len(ishape)):
2122
+ odata = odata * ishape[k]
2123
+
2124
+ ndata = 1
1154
2125
  for k in range(axis):
1155
- ndata=ndata*ishape[k]
2126
+ ndata = ndata * ishape[k]
1156
2127
 
1157
- tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,npiy,odata])
2128
+ tim = self.backend.bk_reshape(
2129
+ self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2130
+ )
1158
2131
 
1159
2132
  if self.backend.bk_is_complex(tim):
1160
- rr1=self.backend.conv2d(self.backend.bk_real(tim),self.ww_RealT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1161
- ii1=self.backend.conv2d(self.backend.bk_real(tim),self.ww_ImagT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1162
- rr2=self.backend.conv2d(self.backend.bk_imag(tim),self.ww_RealT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1163
- ii2=self.backend.conv2d(self.backend.bk_imag(tim),self.ww_ImagT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1164
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2133
+ rr1 = self.backend.conv2d(
2134
+ self.backend.bk_real(tim),
2135
+ self.ww_RealT[odata],
2136
+ strides=[1, 1, 1, 1],
2137
+ padding=self.padding,
2138
+ )
2139
+ ii1 = self.backend.conv2d(
2140
+ self.backend.bk_real(tim),
2141
+ self.ww_ImagT[odata],
2142
+ strides=[1, 1, 1, 1],
2143
+ padding=self.padding,
2144
+ )
2145
+ rr2 = self.backend.conv2d(
2146
+ self.backend.bk_imag(tim),
2147
+ self.ww_RealT[odata],
2148
+ strides=[1, 1, 1, 1],
2149
+ padding=self.padding,
2150
+ )
2151
+ ii2 = self.backend.conv2d(
2152
+ self.backend.bk_imag(tim),
2153
+ self.ww_ImagT[odata],
2154
+ strides=[1, 1, 1, 1],
2155
+ padding=self.padding,
2156
+ )
2157
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2158
+ else:
2159
+ rr = self.backend.conv2d(
2160
+ tim,
2161
+ self.ww_RealT[odata],
2162
+ strides=[1, 1, 1, 1],
2163
+ padding=self.padding,
2164
+ )
2165
+ ii = self.backend.conv2d(
2166
+ tim,
2167
+ self.ww_ImagT[odata],
2168
+ strides=[1, 1, 1, 1],
2169
+ padding=self.padding,
2170
+ )
2171
+ res = self.backend.bk_complex(rr, ii)
2172
+
2173
+ if axis == 0:
2174
+ if len(ishape) == 2:
2175
+ return self.backend.bk_reshape(
2176
+ res, [res.shape[1], res.shape[2], self.NORIENT]
2177
+ )
2178
+ else:
2179
+ return self.backend.bk_reshape(
2180
+ res,
2181
+ [res.shape[1], res.shape[2], self.NORIENT] + ishape[axis + 2 :],
2182
+ )
1165
2183
  else:
1166
- rr=self.backend.conv2d(tim,self.ww_RealT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1167
- ii=self.backend.conv2d(tim,self.ww_ImagT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1168
- res=self.backend.bk_complex(rr,ii)
1169
-
1170
- if axis==0:
1171
- if len(ishape)==2:
1172
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2],self.NORIENT])
2184
+ if len(ishape) == axis + 2:
2185
+ return self.backend.bk_reshape(
2186
+ res, ishape[0:axis] + [res.shape[1], res.shape[2], self.NORIENT]
2187
+ )
1173
2188
  else:
1174
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
2189
+ return self.backend.bk_reshape(
2190
+ res,
2191
+ ishape[0:axis]
2192
+ + [res.shape[1], res.shape[2], self.NORIENT]
2193
+ + ishape[axis + 2 :],
2194
+ )
2195
+
2196
+ return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
2197
+ elif self.use_1D:
2198
+ ishape = list(in_image.shape)
2199
+ if len(ishape) < axis + 1:
2200
+ if not self.silent:
2201
+ print("Use of 1D scat with data that has less than 1D")
2202
+ return None
2203
+
2204
+ npix = ishape[axis]
2205
+ odata = 1
2206
+ if len(ishape) > axis + 1:
2207
+ for k in range(axis + 1, len(ishape)):
2208
+ odata = odata * ishape[k]
2209
+
2210
+ ndata = 1
2211
+ for k in range(axis):
2212
+ ndata = ndata * ishape[k]
2213
+
2214
+ tim = self.backend.bk_reshape(
2215
+ self.backend.bk_cast(in_image), [ndata, npix, odata]
2216
+ )
2217
+
2218
+ if self.backend.bk_is_complex(tim):
2219
+ rr1 = self.backend.conv1d(
2220
+ self.backend.bk_real(tim),
2221
+ self.ww_RealT[odata],
2222
+ strides=[1, 1, 1],
2223
+ padding=self.padding,
2224
+ )
2225
+ ii1 = self.backend.conv1d(
2226
+ self.backend.bk_real(tim),
2227
+ self.ww_ImagT[odata],
2228
+ strides=[1, 1, 1],
2229
+ padding=self.padding,
2230
+ )
2231
+ rr2 = self.backend.conv1d(
2232
+ self.backend.bk_imag(tim),
2233
+ self.ww_RealT[odata],
2234
+ strides=[1, 1, 1],
2235
+ padding=self.padding,
2236
+ )
2237
+ ii2 = self.backend.conv1d(
2238
+ self.backend.bk_imag(tim),
2239
+ self.ww_ImagT[odata],
2240
+ strides=[1, 1, 1],
2241
+ padding=self.padding,
2242
+ )
2243
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1175
2244
  else:
1176
- if len(ishape)==axis+2:
1177
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT])
2245
+ rr = self.backend.conv1d(
2246
+ tim, self.ww_RealT[odata], strides=[1, 1, 1], padding=self.padding
2247
+ )
2248
+ ii = self.backend.conv1d(
2249
+ tim, self.ww_ImagT[odata], strides=[1, 1, 1], padding=self.padding
2250
+ )
2251
+ res = self.backend.bk_complex(rr, ii)
2252
+
2253
+ if axis == 0:
2254
+ if len(ishape) == 1:
2255
+ return self.backend.bk_reshape(res, [res.shape[1]])
1178
2256
  else:
1179
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
1180
-
1181
- return self.backend.bk_reshape(res,[nout,nouty])
1182
-
2257
+ return self.backend.bk_reshape(
2258
+ res, [res.shape[1]] + ishape[axis + 2 :]
2259
+ )
2260
+ else:
2261
+ if len(ishape) == axis + 1:
2262
+ return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
2263
+ else:
2264
+ return self.backend.bk_reshape(
2265
+ res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2266
+ )
2267
+
2268
+ return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
2269
+
1183
2270
  else:
1184
- nside=int(np.sqrt(image.shape[axis]//12))
2271
+ nside = int(np.sqrt(image.shape[axis] // 12))
1185
2272
 
1186
2273
  if self.Idx_Neighbours[nside] is None:
1187
2274
  if self.InitWave is None:
1188
- wr,wi,ws,widx=self.init_index(nside)
2275
+ wr, wi, ws, widx = self.init_index(nside)
1189
2276
  else:
1190
- wr,wi,ws,widx=self.InitWave(self,nside)
1191
-
1192
- self.Idx_Neighbours[nside]=1 #self.backend.constant(tmp)
1193
- self.ww_Real[nside]=wr
1194
- self.ww_Imag[nside]=wi
1195
- self.w_smooth[nside]=ws
1196
-
1197
- l_ww_real=self.ww_Real[nside]
1198
- l_ww_imag=self.ww_Imag[nside]
1199
-
1200
- ishape=list(image.shape)
1201
- odata=1
1202
- for k in range(axis+1,len(ishape)):
1203
- odata=odata*ishape[k]
1204
-
1205
- if axis>0:
1206
- ndata=1
2277
+ wr, wi, ws, widx = self.InitWave(self, nside)
2278
+
2279
+ self.Idx_Neighbours[nside] = 1 # self.backend.constant(tmp)
2280
+ self.ww_Real[nside] = wr
2281
+ self.ww_Imag[nside] = wi
2282
+ self.w_smooth[nside] = ws
2283
+
2284
+ l_ww_real = self.ww_Real[nside]
2285
+ l_ww_imag = self.ww_Imag[nside]
2286
+
2287
+ ishape = list(image.shape)
2288
+ odata = 1
2289
+ for k in range(axis + 1, len(ishape)):
2290
+ odata = odata * ishape[k]
2291
+
2292
+ if axis > 0:
2293
+ ndata = 1
1207
2294
  for k in range(axis):
1208
- ndata=ndata*ishape[k]
1209
- tim=self.backend.bk_reshape(self.backend.bk_cast(image),[ndata,12*nside**2,odata])
1210
- if tim.dtype==self.all_cbk_type:
1211
- rr1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_real(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1212
- ii1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_real(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1213
- rr2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_imag(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1214
- ii2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_imag(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1215
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2295
+ ndata = ndata * ishape[k]
2296
+ tim = self.backend.bk_reshape(
2297
+ self.backend.bk_cast(image), [ndata, 12 * nside**2, odata]
2298
+ )
2299
+ if tim.dtype == self.all_cbk_type:
2300
+ rr1 = self.backend.bk_reshape(
2301
+ self.backend.bk_sparse_dense_matmul(
2302
+ l_ww_real, self.backend.bk_real(tim[0])
2303
+ ),
2304
+ [1, 12 * nside**2, self.NORIENT, odata],
2305
+ )
2306
+ ii1 = self.backend.bk_reshape(
2307
+ self.backend.bk_sparse_dense_matmul(
2308
+ l_ww_imag, self.backend.bk_real(tim[0])
2309
+ ),
2310
+ [1, 12 * nside**2, self.NORIENT, odata],
2311
+ )
2312
+ rr2 = self.backend.bk_reshape(
2313
+ self.backend.bk_sparse_dense_matmul(
2314
+ l_ww_real, self.backend.bk_imag(tim[0])
2315
+ ),
2316
+ [1, 12 * nside**2, self.NORIENT, odata],
2317
+ )
2318
+ ii2 = self.backend.bk_reshape(
2319
+ self.backend.bk_sparse_dense_matmul(
2320
+ l_ww_imag, self.backend.bk_imag(tim[0])
2321
+ ),
2322
+ [1, 12 * nside**2, self.NORIENT, odata],
2323
+ )
2324
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1216
2325
  else:
1217
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,tim[0]),[1,12*nside**2,self.NORIENT,odata])
1218
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,tim[0]),[1,12*nside**2,self.NORIENT,odata])
1219
- res=self.backend.bk_complex(rr,ii)
1220
-
1221
- for k in range(1,ndata):
1222
- if tim.dtype==self.all_cbk_type:
1223
- rr1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_real(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1224
- ii1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_real(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1225
- rr2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_imag(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1226
- ii2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_imag(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1227
- res=self.backend.bk_concat([res,self.backend.bk_complex(rr1-ii2,ii1+rr2)],0)
2326
+ rr = self.backend.bk_reshape(
2327
+ self.backend.bk_sparse_dense_matmul(l_ww_real, tim[0]),
2328
+ [1, 12 * nside**2, self.NORIENT, odata],
2329
+ )
2330
+ ii = self.backend.bk_reshape(
2331
+ self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[0]),
2332
+ [1, 12 * nside**2, self.NORIENT, odata],
2333
+ )
2334
+ res = self.backend.bk_complex(rr, ii)
2335
+
2336
+ for k in range(1, ndata):
2337
+ if tim.dtype == self.all_cbk_type:
2338
+ rr1 = self.backend.bk_reshape(
2339
+ self.backend.bk_sparse_dense_matmul(
2340
+ l_ww_real, self.backend.bk_real(tim[k])
2341
+ ),
2342
+ [1, 12 * nside**2, self.NORIENT, odata],
2343
+ )
2344
+ ii1 = self.backend.bk_reshape(
2345
+ self.backend.bk_sparse_dense_matmul(
2346
+ l_ww_imag, self.backend.bk_real(tim[k])
2347
+ ),
2348
+ [1, 12 * nside**2, self.NORIENT, odata],
2349
+ )
2350
+ rr2 = self.backend.bk_reshape(
2351
+ self.backend.bk_sparse_dense_matmul(
2352
+ l_ww_real, self.backend.bk_imag(tim[k])
2353
+ ),
2354
+ [1, 12 * nside**2, self.NORIENT, odata],
2355
+ )
2356
+ ii2 = self.backend.bk_reshape(
2357
+ self.backend.bk_sparse_dense_matmul(
2358
+ l_ww_imag, self.backend.bk_imag(tim[k])
2359
+ ),
2360
+ [1, 12 * nside**2, self.NORIENT, odata],
2361
+ )
2362
+ res = self.backend.bk_concat(
2363
+ [res, self.backend.bk_complex(rr1 - ii2, ii1 + rr2)], 0
2364
+ )
1228
2365
  else:
1229
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,tim[k]),[1,12*nside**2,self.NORIENT,odata])
1230
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,tim[k]),[1,12*nside**2,self.NORIENT,odata])
1231
- res=self.backend.bk_concat([res,self.backend.bk_complex(rr,ii)],0)
1232
-
1233
- if len(ishape)==axis+1:
1234
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2,self.NORIENT])
2366
+ rr = self.backend.bk_reshape(
2367
+ self.backend.bk_sparse_dense_matmul(l_ww_real, tim[k]),
2368
+ [1, 12 * nside**2, self.NORIENT, odata],
2369
+ )
2370
+ ii = self.backend.bk_reshape(
2371
+ self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[k]),
2372
+ [1, 12 * nside**2, self.NORIENT, odata],
2373
+ )
2374
+ res = self.backend.bk_concat(
2375
+ [res, self.backend.bk_complex(rr, ii)], 0
2376
+ )
2377
+
2378
+ if len(ishape) == axis + 1:
2379
+ return self.backend.bk_reshape(
2380
+ res, ishape[0:axis] + [12 * nside**2, self.NORIENT]
2381
+ )
1235
2382
  else:
1236
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2]+ishape[axis+1:]+[self.NORIENT])
1237
-
1238
- if axis==0:
1239
- tim=self.backend.bk_reshape(self.backend.bk_cast(image),[12*nside**2,odata])
1240
- if tim.dtype==self.all_cbk_type:
1241
- rr1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_real(tim)),[12*nside**2,self.NORIENT,odata])
1242
- ii1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_real(tim)),[12*nside**2,self.NORIENT,odata])
1243
- rr2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_imag(tim)),[12*nside**2,self.NORIENT,odata])
1244
- ii2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_imag(tim)),[12*nside**2,self.NORIENT,odata])
1245
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2383
+ return self.backend.bk_reshape(
2384
+ res,
2385
+ ishape[0:axis]
2386
+ + [12 * nside**2]
2387
+ + ishape[axis + 1 :]
2388
+ + [self.NORIENT],
2389
+ )
2390
+
2391
+ if axis == 0:
2392
+ tim = self.backend.bk_reshape(
2393
+ self.backend.bk_cast(image), [12 * nside**2, odata]
2394
+ )
2395
+ if tim.dtype == self.all_cbk_type:
2396
+ rr1 = self.backend.bk_reshape(
2397
+ self.backend.bk_sparse_dense_matmul(
2398
+ l_ww_real, self.backend.bk_real(tim)
2399
+ ),
2400
+ [12 * nside**2, self.NORIENT, odata],
2401
+ )
2402
+ ii1 = self.backend.bk_reshape(
2403
+ self.backend.bk_sparse_dense_matmul(
2404
+ l_ww_imag, self.backend.bk_real(tim)
2405
+ ),
2406
+ [12 * nside**2, self.NORIENT, odata],
2407
+ )
2408
+ rr2 = self.backend.bk_reshape(
2409
+ self.backend.bk_sparse_dense_matmul(
2410
+ l_ww_real, self.backend.bk_imag(tim)
2411
+ ),
2412
+ [12 * nside**2, self.NORIENT, odata],
2413
+ )
2414
+ ii2 = self.backend.bk_reshape(
2415
+ self.backend.bk_sparse_dense_matmul(
2416
+ l_ww_imag, self.backend.bk_imag(tim)
2417
+ ),
2418
+ [12 * nside**2, self.NORIENT, odata],
2419
+ )
2420
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1246
2421
  else:
1247
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,tim),[12*nside**2,self.NORIENT,odata])
1248
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,tim),[12*nside**2,self.NORIENT,odata])
1249
- res=self.backend.bk_complex(rr,ii)
1250
-
1251
- if len(ishape)==1:
1252
- return self.backend.bk_reshape(res,[12*nside**2,self.NORIENT])
2422
+ rr = self.backend.bk_reshape(
2423
+ self.backend.bk_sparse_dense_matmul(l_ww_real, tim),
2424
+ [12 * nside**2, self.NORIENT, odata],
2425
+ )
2426
+ ii = self.backend.bk_reshape(
2427
+ self.backend.bk_sparse_dense_matmul(l_ww_imag, tim),
2428
+ [12 * nside**2, self.NORIENT, odata],
2429
+ )
2430
+ res = self.backend.bk_complex(rr, ii)
2431
+
2432
+ if len(ishape) == 1:
2433
+ return self.backend.bk_reshape(res, [12 * nside**2, self.NORIENT])
1253
2434
  else:
1254
- return self.backend.bk_reshape(res,[12*nside**2]+ishape[axis+1:]+[self.NORIENT])
1255
- return(res)
1256
-
2435
+ return self.backend.bk_reshape(
2436
+ res, [12 * nside**2] + ishape[axis + 1 :] + [self.NORIENT]
2437
+ )
2438
+ return res
1257
2439
 
1258
2440
  # ---------------------------------------------−---------
1259
- def smooth(self,in_image,axis=0):
2441
+ def smooth(self, in_image, axis=0):
2442
+
2443
+ image = self.backend.bk_cast(in_image)
1260
2444
 
1261
- image=self.backend.bk_cast(in_image)
1262
-
1263
2445
  if self.use_2D:
1264
-
1265
- ishape=list(in_image.shape)
1266
- if len(ishape)<axis+2:
1267
- print('Use of 2D scat with data that has less than 2D')
1268
- exit(0)
1269
-
1270
- npix=ishape[axis]
1271
- npiy=ishape[axis+1]
1272
- odata=1
1273
- if len(ishape)>axis+2:
1274
- for k in range(axis+2,len(ishape)):
1275
- odata=odata*ishape[k]
1276
-
1277
- ndata=1
2446
+
2447
+ ishape = list(in_image.shape)
2448
+ if len(ishape) < axis + 2:
2449
+ if not self.silent:
2450
+ print("Use of 2D scat with data that has less than 2D")
2451
+ return None
2452
+
2453
+ npix = ishape[axis]
2454
+ npiy = ishape[axis + 1]
2455
+ odata = 1
2456
+ if len(ishape) > axis + 2:
2457
+ for k in range(axis + 2, len(ishape)):
2458
+ odata = odata * ishape[k]
2459
+
2460
+ ndata = 1
1278
2461
  for k in range(axis):
1279
- ndata=ndata*ishape[k]
2462
+ ndata = ndata * ishape[k]
1280
2463
 
1281
- tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,npiy,odata])
2464
+ tim = self.backend.bk_reshape(
2465
+ self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2466
+ )
1282
2467
 
1283
2468
  if self.backend.bk_is_complex(tim):
1284
- rr=self.backend.conv2d(self.backend.bk_real(tim),self.ww_SmoothT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1285
- ii=self.backend.conv2d(self.backend.bk_imag(tim),self.ww_SmoothT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1286
- res=self.backend.bk_complex(rr,ii)
2469
+ rr = self.backend.conv2d(
2470
+ self.backend.bk_real(tim),
2471
+ self.ww_SmoothT[odata],
2472
+ strides=[1, 1, 1, 1],
2473
+ padding=self.padding,
2474
+ )
2475
+ ii = self.backend.conv2d(
2476
+ self.backend.bk_imag(tim),
2477
+ self.ww_SmoothT[odata],
2478
+ strides=[1, 1, 1, 1],
2479
+ padding=self.padding,
2480
+ )
2481
+ res = self.backend.bk_complex(rr, ii)
1287
2482
  else:
1288
- res=self.backend.conv2d(tim,self.ww_SmoothT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1289
-
1290
- if axis==0:
1291
- if len(ishape)==2:
1292
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2]])
2483
+ res = self.backend.conv2d(
2484
+ tim,
2485
+ self.ww_SmoothT[odata],
2486
+ strides=[1, 1, 1, 1],
2487
+ padding=self.padding,
2488
+ )
2489
+
2490
+ if axis == 0:
2491
+ if len(ishape) == 2:
2492
+ return self.backend.bk_reshape(res, [res.shape[1], res.shape[2]])
1293
2493
  else:
1294
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2]]+ishape[axis+2:])
2494
+ return self.backend.bk_reshape(
2495
+ res, [res.shape[1], res.shape[2]] + ishape[axis + 2 :]
2496
+ )
1295
2497
  else:
1296
- if len(ishape)==axis+2:
1297
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]])
2498
+ if len(ishape) == axis + 2:
2499
+ return self.backend.bk_reshape(
2500
+ res, ishape[0:axis] + [res.shape[1], res.shape[2]]
2501
+ )
1298
2502
  else:
1299
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]]+ishape[axis+2:])
1300
-
1301
- return self.backend.bk_reshape(res,[nout,nouty])
1302
-
2503
+ return self.backend.bk_reshape(
2504
+ res,
2505
+ ishape[0:axis]
2506
+ + [res.shape[1], res.shape[2]]
2507
+ + ishape[axis + 2 :],
2508
+ )
2509
+
2510
+ return self.backend.bk_reshape(res, in_image.shape)
2511
+ elif self.use_1D:
2512
+
2513
+ ishape = list(in_image.shape)
2514
+ if len(ishape) < axis + 1:
2515
+ if not self.silent:
2516
+ print("Use of 1D scat with data that has less than 1D")
2517
+ return None
2518
+
2519
+ npix = ishape[axis]
2520
+ odata = 1
2521
+ if len(ishape) > axis + 1:
2522
+ for k in range(axis + 1, len(ishape)):
2523
+ odata = odata * ishape[k]
2524
+
2525
+ ndata = 1
2526
+ for k in range(axis):
2527
+ ndata = ndata * ishape[k]
2528
+
2529
+ tim = self.backend.bk_reshape(
2530
+ self.backend.bk_cast(in_image), [ndata, npix, odata]
2531
+ )
2532
+
2533
+ if self.backend.bk_is_complex(tim):
2534
+ rr = self.backend.conv1d(
2535
+ self.backend.bk_real(tim),
2536
+ self.ww_SmoothT[odata],
2537
+ strides=[1, 1, 1],
2538
+ padding=self.padding,
2539
+ )
2540
+ ii = self.backend.conv1d(
2541
+ self.backend.bk_imag(tim),
2542
+ self.ww_SmoothT[odata],
2543
+ strides=[1, 1, 1],
2544
+ padding=self.padding,
2545
+ )
2546
+ res = self.backend.bk_complex(rr, ii)
2547
+ else:
2548
+ res = self.backend.conv1d(
2549
+ tim, self.ww_SmoothT[odata], strides=[1, 1, 1], padding=self.padding
2550
+ )
2551
+
2552
+ if axis == 0:
2553
+ if len(ishape) == 1:
2554
+ return self.backend.bk_reshape(res, [res.shape[1]])
2555
+ else:
2556
+ return self.backend.bk_reshape(
2557
+ res, [res.shape[1]] + ishape[axis + 1 :]
2558
+ )
2559
+ else:
2560
+ if len(ishape) == axis + 1:
2561
+ return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
2562
+ else:
2563
+ return self.backend.bk_reshape(
2564
+ res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2565
+ )
2566
+
2567
+ return self.backend.bk_reshape(res, in_image.shape)
2568
+
1303
2569
  else:
1304
- nside=int(np.sqrt(image.shape[axis]//12))
2570
+ nside = int(np.sqrt(image.shape[axis] // 12))
1305
2571
 
1306
2572
  if self.Idx_Neighbours[nside] is None:
1307
-
2573
+
1308
2574
  if self.InitWave is None:
1309
- wr,wi,ws,widx=self.init_index(nside)
2575
+ wr, wi, ws, widx = self.init_index(nside)
1310
2576
  else:
1311
- wr,wi,ws,widx=self.InitWave(self,nside)
1312
-
1313
- self.Idx_Neighbours[nside]=1
1314
- self.ww_Real[nside]=wr
1315
- self.ww_Imag[nside]=wi
1316
- self.w_smooth[nside]=ws
1317
-
1318
- l_w_smooth=self.w_smooth[nside]
1319
- ishape=list(image.shape)
1320
-
1321
- odata=1
1322
- for k in range(axis+1,len(ishape)):
1323
- odata=odata*ishape[k]
1324
-
1325
- if axis==0:
1326
- tim=self.backend.bk_reshape(image,[12*nside**2,odata])
1327
- if tim.dtype==self.all_cbk_type:
1328
- rr=self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_real(tim))
1329
- ri=self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_imag(tim))
1330
- res=self.backend.bk_complex(rr,ri)
2577
+ wr, wi, ws, widx = self.InitWave(self, nside)
2578
+
2579
+ self.Idx_Neighbours[nside] = 1
2580
+ self.ww_Real[nside] = wr
2581
+ self.ww_Imag[nside] = wi
2582
+ self.w_smooth[nside] = ws
2583
+
2584
+ l_w_smooth = self.w_smooth[nside]
2585
+ ishape = list(image.shape)
2586
+
2587
+ odata = 1
2588
+ for k in range(axis + 1, len(ishape)):
2589
+ odata = odata * ishape[k]
2590
+
2591
+ if axis == 0:
2592
+ tim = self.backend.bk_reshape(image, [12 * nside**2, odata])
2593
+ if tim.dtype == self.all_cbk_type:
2594
+ rr = self.backend.bk_sparse_dense_matmul(
2595
+ l_w_smooth, self.backend.bk_real(tim)
2596
+ )
2597
+ ri = self.backend.bk_sparse_dense_matmul(
2598
+ l_w_smooth, self.backend.bk_imag(tim)
2599
+ )
2600
+ res = self.backend.bk_complex(rr, ri)
1331
2601
  else:
1332
- res=self.backend.bk_sparse_dense_matmul(l_w_smooth,tim)
1333
- if len(ishape)==1:
1334
- return self.backend.bk_reshape(res,[12*nside**2])
2602
+ res = self.backend.bk_sparse_dense_matmul(l_w_smooth, tim)
2603
+ if len(ishape) == 1:
2604
+ return self.backend.bk_reshape(res, [12 * nside**2])
1335
2605
  else:
1336
- return self.backend.bk_reshape(res,[12*nside**2]+ishape[axis+1:])
1337
-
1338
- if axis>0:
1339
- ndata=ishape[0]
1340
- for k in range(1,axis):
1341
- ndata=ndata*ishape[k]
1342
- tim=self.backend.bk_reshape(image,[ndata,12*nside**2,odata])
1343
- if tim.dtype==self.all_cbk_type:
1344
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_real(tim[0])),[1,12*nside**2,odata])
1345
- ri=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_imag(tim[0])),[1,12*nside**2,odata])
1346
- res=self.backend.bk_complex(rr,ri)
2606
+ return self.backend.bk_reshape(
2607
+ res, [12 * nside**2] + ishape[axis + 1 :]
2608
+ )
2609
+
2610
+ if axis > 0:
2611
+ ndata = ishape[0]
2612
+ for k in range(1, axis):
2613
+ ndata = ndata * ishape[k]
2614
+ tim = self.backend.bk_reshape(image, [ndata, 12 * nside**2, odata])
2615
+ if tim.dtype == self.all_cbk_type:
2616
+ rr = self.backend.bk_reshape(
2617
+ self.backend.bk_sparse_dense_matmul(
2618
+ l_w_smooth, self.backend.bk_real(tim[0])
2619
+ ),
2620
+ [1, 12 * nside**2, odata],
2621
+ )
2622
+ ri = self.backend.bk_reshape(
2623
+ self.backend.bk_sparse_dense_matmul(
2624
+ l_w_smooth, self.backend.bk_imag(tim[0])
2625
+ ),
2626
+ [1, 12 * nside**2, odata],
2627
+ )
2628
+ res = self.backend.bk_complex(rr, ri)
1347
2629
  else:
1348
- res=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,tim[0]),[1,12*nside**2,odata])
1349
-
1350
- for k in range(1,ndata):
1351
- if tim.dtype==self.all_cbk_type:
1352
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_real(tim[k])),[1,12*nside**2,odata])
1353
- ri=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_imag(tim[k])),[1,12*nside**2,odata])
1354
- res=self.backend.bk_concat([res,self.backend.bk_complex(rr,ri)],0)
2630
+ res = self.backend.bk_reshape(
2631
+ self.backend.bk_sparse_dense_matmul(l_w_smooth, tim[0]),
2632
+ [1, 12 * nside**2, odata],
2633
+ )
2634
+
2635
+ for k in range(1, ndata):
2636
+ if tim.dtype == self.all_cbk_type:
2637
+ rr = self.backend.bk_reshape(
2638
+ self.backend.bk_sparse_dense_matmul(
2639
+ l_w_smooth, self.backend.bk_real(tim[k])
2640
+ ),
2641
+ [1, 12 * nside**2, odata],
2642
+ )
2643
+ ri = self.backend.bk_reshape(
2644
+ self.backend.bk_sparse_dense_matmul(
2645
+ l_w_smooth, self.backend.bk_imag(tim[k])
2646
+ ),
2647
+ [1, 12 * nside**2, odata],
2648
+ )
2649
+ res = self.backend.bk_concat(
2650
+ [res, self.backend.bk_complex(rr, ri)], 0
2651
+ )
1355
2652
  else:
1356
- res=self.backend.bk_concat([res,self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,tim[k]),[1,12*nside**2,odata])],0)
1357
-
1358
- if len(ishape)==axis+1:
1359
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2])
2653
+ res = self.backend.bk_concat(
2654
+ [
2655
+ res,
2656
+ self.backend.bk_reshape(
2657
+ self.backend.bk_sparse_dense_matmul(
2658
+ l_w_smooth, tim[k]
2659
+ ),
2660
+ [1, 12 * nside**2, odata],
2661
+ ),
2662
+ ],
2663
+ 0,
2664
+ )
2665
+
2666
+ if len(ishape) == axis + 1:
2667
+ return self.backend.bk_reshape(
2668
+ res, ishape[0:axis] + [12 * nside**2]
2669
+ )
1360
2670
  else:
1361
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2]+ishape[axis+1:])
1362
-
1363
-
1364
- return(res)
1365
-
2671
+ return self.backend.bk_reshape(
2672
+ res, ishape[0:axis] + [12 * nside**2] + ishape[axis + 1 :]
2673
+ )
2674
+
2675
+ return res
2676
+
1366
2677
  # ---------------------------------------------−---------
1367
2678
  def get_kernel_size(self):
1368
- return(self.KERNELSZ)
1369
-
2679
+ return self.KERNELSZ
2680
+
1370
2681
  # ---------------------------------------------−---------
1371
2682
  def get_nb_orient(self):
1372
- return(self.NORIENT)
1373
-
2683
+ return self.NORIENT
2684
+
1374
2685
  # ---------------------------------------------−---------
1375
- def get_ww(self,nside=1):
1376
- return(self.ww_Real[nside],self.ww_Imag[nside])
1377
-
2686
+ def get_ww(self, nside=1):
2687
+ return (self.ww_Real[nside], self.ww_Imag[nside])
2688
+
1378
2689
  # ---------------------------------------------−---------
1379
2690
  def plot_ww(self):
1380
- c,s=self.get_ww()
2691
+ c, s = self.get_ww()
1381
2692
  import matplotlib.pyplot as plt
1382
- plt.figure(figsize=(16,6))
1383
- npt=int(np.sqrt(c.shape[0]))
2693
+
2694
+ plt.figure(figsize=(16, 6))
2695
+ npt = int(np.sqrt(c.shape[0]))
1384
2696
  for i in range(c.shape[1]):
1385
- plt.subplot(2,c.shape[1],1+i)
1386
- plt.imshow(c[:,i].reshape(npt,npt),cmap='jet',vmin=-c.max(),vmax=c.max())
1387
- plt.subplot(2,c.shape[1],1+i+c.shape[1])
1388
- plt.imshow(s[:,i].reshape(npt,npt),cmap='jet',vmin=-c.max(),vmax=c.max())
2697
+ plt.subplot(2, c.shape[1], 1 + i)
2698
+ plt.imshow(
2699
+ c[:, i].reshape(npt, npt), cmap="jet", vmin=-c.max(), vmax=c.max()
2700
+ )
2701
+ plt.subplot(2, c.shape[1], 1 + i + c.shape[1])
2702
+ plt.imshow(
2703
+ s[:, i].reshape(npt, npt), cmap="jet", vmin=-c.max(), vmax=c.max()
2704
+ )
1389
2705
  sys.stdout.flush()
1390
2706
  plt.show()
1391
-
1392
-
1393
-
1394
-
1395
-