foscat 3.0.8__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -1,762 +1,1468 @@
1
- import numpy as np
1
+ import os
2
+ import sys
3
+
2
4
  import healpy as hp
3
- import os, sys
4
- import foscat.backend as bk
5
+ import numpy as np
5
6
  from scipy.interpolate import griddata
6
7
 
7
- TMPFILE_VERSION='V2_6'
8
+ import foscat.backend as bk
8
9
 
9
- class FoCUS:
10
- def __init__(self,
11
- NORIENT=4,
12
- LAMBDA=1.2,
13
- KERNELSZ=3,
14
- slope=1.0,
15
- all_type='float64',
16
- nstep_max=16,
17
- padding='SAME',
18
- gpupos=0,
19
- mask_thres=None,
20
- mask_norm=False,
21
- OSTEP=0,
22
- isMPI=False,
23
- TEMPLATE_PATH='data',
24
- BACKEND='tensorflow',
25
- use_2D=False,
26
- return_data=False,
27
- JmaxDelta=0,
28
- DODIV=False,
29
- InitWave=None,
30
- mpi_size=1,
31
- mpi_rank=0):
10
+ TMPFILE_VERSION = "V4_0"
32
11
 
12
+
13
+ class FoCUS:
14
+ def __init__(
15
+ self,
16
+ NORIENT=4,
17
+ LAMBDA=1.2,
18
+ KERNELSZ=3,
19
+ slope=1.0,
20
+ all_type="float64",
21
+ nstep_max=16,
22
+ padding="SAME",
23
+ gpupos=0,
24
+ mask_thres=None,
25
+ mask_norm=False,
26
+ OSTEP=0,
27
+ isMPI=False,
28
+ TEMPLATE_PATH="data",
29
+ BACKEND="tensorflow",
30
+ use_2D=False,
31
+ use_1D=False,
32
+ return_data=False,
33
+ JmaxDelta=0,
34
+ DODIV=False,
35
+ InitWave=None,
36
+ silent=False,
37
+ mpi_size=1,
38
+ mpi_rank=0,
39
+ ):
40
+
41
+ self.__version__ = "3.6.0"
33
42
  # P00 coeff for normalization for scat_cov
34
- self.TMPFILE_VERSION=TMPFILE_VERSION
43
+ self.TMPFILE_VERSION = TMPFILE_VERSION
35
44
  self.P1_dic = None
36
45
  self.P2_dic = None
37
- self.isMPI=isMPI
46
+ self.isMPI = isMPI
38
47
  self.mask_thres = mask_thres
39
48
  self.mask_norm = mask_norm
40
- self.InitWave=InitWave
41
-
42
- self.mpi_size=mpi_size
43
- self.mpi_rank=mpi_rank
44
- self.return_data=return_data
45
-
46
- print('================================================')
47
- print(' START FOSCAT CONFIGURATION')
48
- print('================================================')
49
- sys.stdout.flush()
50
-
51
- self.TEMPLATE_PATH=TEMPLATE_PATH
52
- if os.path.exists(self.TEMPLATE_PATH)==False:
53
- print('The directory %s to store temporary information for FoCUS does not exist: Try to create it'%(self.TEMPLATE_PATH))
49
+ self.InitWave = InitWave
50
+
51
+ self.mpi_size = mpi_size
52
+ self.mpi_rank = mpi_rank
53
+ self.return_data = return_data
54
+ self.silent = silent
55
+
56
+ if not self.silent:
57
+ print("================================================")
58
+ print(" START FOSCAT CONFIGURATION")
59
+ print("================================================")
60
+ sys.stdout.flush()
61
+
62
+ self.TEMPLATE_PATH = TEMPLATE_PATH
63
+ if not os.path.exists(self.TEMPLATE_PATH):
64
+ if not self.silent:
65
+ print(
66
+ "The directory %s to store temporary information for FoCUS does not exist: Try to create it"
67
+ % (self.TEMPLATE_PATH)
68
+ )
54
69
  try:
55
- os.system('mkdir -p %s'%(self.TEMPLATE_PATH))
56
- print('The directory %s is created')
70
+ os.system("mkdir -p %s" % (self.TEMPLATE_PATH))
71
+ if not self.silent:
72
+ print("The directory %s is created")
57
73
  except:
58
- print('Impossible to create the directory %s'%(self.TEMPLATE_PATH))
59
- exit(0)
60
-
61
- self.number_of_loss=0
62
-
63
- self.history=np.zeros([10])
64
- self.nlog=0
65
- self.padding=padding
66
-
67
- if OSTEP!=0:
68
- print('OPTION option is deprecated after version 2.0.6. Please use Jmax option')
69
- JmaxDelta=OSTEP
74
+ if not self.silent:
75
+ print(
76
+ "Impossible to create the directory %s" % (self.TEMPLATE_PATH)
77
+ )
78
+ return None
79
+
80
+ self.number_of_loss = 0
81
+
82
+ self.history = np.zeros([10])
83
+ self.nlog = 0
84
+ self.padding = padding
85
+
86
+ if OSTEP != 0:
87
+ if not self.silent:
88
+ print(
89
+ "OPTION option is deprecated after version 2.0.6. Please use Jmax option"
90
+ )
91
+ JmaxDelta = OSTEP
70
92
  else:
71
- OSTEP=JmaxDelta
72
-
73
- if JmaxDelta<-1:
74
- print('Warning : Jmax can not be smaller than -1')
75
- exit(0)
76
-
77
- self.OSTEP=JmaxDelta
78
- self.use_2D=use_2D
79
-
93
+ OSTEP = JmaxDelta
94
+
95
+ if JmaxDelta < -1:
96
+ if not self.silent:
97
+ print("Warning : Jmax can not be smaller than -1")
98
+ return None
99
+
100
+ self.OSTEP = JmaxDelta
101
+ self.use_2D = use_2D
102
+ self.use_1D = use_1D
103
+
80
104
  if isMPI:
81
105
  from mpi4py import MPI
82
106
 
83
- self.comm= MPI.COMM_WORLD
84
- if all_type=='float32':
85
- self.MPI_ALL_TYPE=MPI.FLOAT
107
+ self.comm = MPI.COMM_WORLD
108
+ if all_type == "float32":
109
+ self.MPI_ALL_TYPE = MPI.FLOAT
86
110
  else:
87
- self.MPI_ALL_TYPE=MPI.DOUBLE
111
+ self.MPI_ALL_TYPE = MPI.DOUBLE
88
112
  else:
89
- self.MPI_ALL_TYPE=None
90
-
91
- self.all_type=all_type
92
- self.BACKEND=BACKEND
93
- self.backend=bk.foscat_backend(BACKEND,
94
- all_type=all_type,
95
- mpi_rank=mpi_rank,
96
- gpupos=gpupos)
97
-
98
- self.all_bk_type=self.backend.all_bk_type
99
- self.all_cbk_type=self.backend.all_cbk_type
100
- self.gpulist=self.backend.gpulist
101
- self.ngpu=self.backend.ngpu
102
- self.rank=mpi_rank
103
-
104
- self.gpupos=(gpupos+mpi_rank)%self.backend.ngpu
105
-
106
- print('============================================================')
107
- print('== ==')
108
- print('== ==')
109
- print('== RUN ON GPU Rank %d : %s =='%(mpi_rank,self.gpulist[self.gpupos%self.ngpu]))
110
- print('== ==')
111
- print('== ==')
112
- print('============================================================')
113
- sys.stdout.flush()
114
-
115
- l_NORIENT=NORIENT
113
+ self.MPI_ALL_TYPE = None
114
+
115
+ self.all_type = all_type
116
+ self.BACKEND = BACKEND
117
+ self.backend = bk.foscat_backend(
118
+ BACKEND,
119
+ all_type=all_type,
120
+ mpi_rank=mpi_rank,
121
+ gpupos=gpupos,
122
+ silent=self.silent,
123
+ )
124
+
125
+ self.all_bk_type = self.backend.all_bk_type
126
+ self.all_cbk_type = self.backend.all_cbk_type
127
+ self.gpulist = self.backend.gpulist
128
+ self.ngpu = self.backend.ngpu
129
+ self.rank = mpi_rank
130
+
131
+ self.gpupos = (gpupos + mpi_rank) % self.backend.ngpu
132
+
133
+ if not self.silent:
134
+ print("============================================================")
135
+ print("== ==")
136
+ print("== ==")
137
+ print(
138
+ "== RUN ON GPU Rank %d : %s =="
139
+ % (mpi_rank, self.gpulist[self.gpupos % self.ngpu])
140
+ )
141
+ print("== ==")
142
+ print("== ==")
143
+ print("============================================================")
144
+ sys.stdout.flush()
145
+
146
+ l_NORIENT = NORIENT
116
147
  if DODIV:
117
- l_NORIENT=NORIENT+2
118
-
119
- self.NORIENT=l_NORIENT
120
- self.LAMBDA=LAMBDA
121
- self.slope=slope
122
-
123
- self.R_off=(KERNELSZ-1)//2
124
- if (self.R_off//2)*2<self.R_off:
125
- self.R_off+=1
126
-
127
- self.ww_Real = {}
128
- self.ww_Imag = {}
129
-
130
- wwc=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
131
- wws=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
132
-
133
- x=np.repeat(np.arange(KERNELSZ)-KERNELSZ//2,KERNELSZ).reshape(KERNELSZ,KERNELSZ)
134
- y=x.T
135
-
136
- if NORIENT==1:
137
- xx=(3/float(KERNELSZ))*LAMBDA*x
138
- yy=(3/float(KERNELSZ))*LAMBDA*y
139
-
140
- if KERNELSZ==5:
141
- #w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
142
- w_smooth=np.exp(-(xx**2+yy**2))
143
- tmp=np.exp(-2*(xx**2+yy**2))-0.25*np.exp(-0.5*(xx**2+yy**2))
148
+ l_NORIENT = NORIENT + 2
149
+
150
+ self.NORIENT = l_NORIENT
151
+ self.LAMBDA = LAMBDA
152
+ self.slope = slope
153
+
154
+ self.R_off = (KERNELSZ - 1) // 2
155
+ if (self.R_off // 2) * 2 < self.R_off:
156
+ self.R_off += 1
157
+
158
+ self.ww_Real = {}
159
+ self.ww_Imag = {}
160
+ self.ww_CNN_Transpose = {}
161
+ self.ww_CNN = {}
162
+ self.X_CNN = {}
163
+ self.Y_CNN = {}
164
+ self.Z_CNN = {}
165
+
166
+ wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
167
+ wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
168
+
169
+ x = np.repeat(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape(
170
+ KERNELSZ, KERNELSZ
171
+ )
172
+ y = x.T
173
+
174
+ if NORIENT == 1:
175
+ xx = (3 / float(KERNELSZ)) * LAMBDA * x
176
+ yy = (3 / float(KERNELSZ)) * LAMBDA * y
177
+
178
+ if KERNELSZ == 5:
179
+ # w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
180
+ w_smooth = np.exp(-(xx**2 + yy**2))
181
+ tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp(
182
+ -0.5 * (xx**2 + yy**2)
183
+ )
144
184
  else:
145
- w_smooth=np.exp(-0.5*(xx**2+yy**2))
146
- tmp=np.exp(-2*(xx**2+yy**2))-0.25*np.exp(-0.5*(xx**2+yy**2))
147
-
148
- wwc[:,0]=tmp.flatten()-tmp.mean()
149
- tmp=0*w_smooth
150
- wws[:,0]=tmp.flatten()
151
- sigma=np.sqrt((wwc[:,0]**2).mean())
152
- wwc[:,0]/=sigma
153
- wws[:,0]/=sigma
154
-
155
- w_smooth=w_smooth.flatten()
185
+ w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
186
+ tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp(
187
+ -0.5 * (xx**2 + yy**2)
188
+ )
189
+
190
+ wwc[:, 0] = tmp.flatten() - tmp.mean()
191
+ tmp = 0 * w_smooth
192
+ wws[:, 0] = tmp.flatten()
193
+ sigma = np.sqrt((wwc[:, 0] ** 2).mean())
194
+ wwc[:, 0] /= sigma
195
+ wws[:, 0] /= sigma
196
+
197
+ w_smooth = w_smooth.flatten()
156
198
  else:
157
199
  for i in range(NORIENT):
158
- a=i/float(NORIENT)*np.pi
159
- xx=(3/float(KERNELSZ))*LAMBDA*(x*np.cos(a)+y*np.sin(a))
160
- yy=(3/float(KERNELSZ))*LAMBDA*(x*np.sin(a)-y*np.cos(a))
200
+ a = i / float(NORIENT) * np.pi
201
+ xx = (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
202
+ yy = (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
161
203
 
162
- if KERNELSZ==5:
163
- #w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
164
- w_smooth=np.exp(-(xx**2+yy**2))
204
+ if KERNELSZ == 5:
205
+ # w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
206
+ w_smooth = np.exp(-(xx**2 + yy**2))
165
207
  else:
166
- w_smooth=np.exp(-0.5*(xx**2+yy**2))
167
- tmp1=np.cos(yy*np.pi)*w_smooth
168
- tmp2=np.sin(yy*np.pi)*w_smooth
169
-
170
- wwc[:,i]=tmp1.flatten()-tmp1.mean()
171
- wws[:,i]=tmp2.flatten()-tmp2.mean()
172
- sigma=np.sqrt((wwc[:,i]**2).mean())
173
- wwc[:,i]/=sigma
174
- wws[:,i]/=sigma
175
-
176
- if DODIV and i==0:
177
- r=(xx**2+yy**2)
178
- theta=np.arctan2(yy,xx)
179
- theta[KERNELSZ//2,KERNELSZ//2]=0.0
180
- tmp1=r*np.cos(2*theta)*w_smooth
181
- tmp2=r*np.sin(2*theta)*w_smooth
182
-
183
- wwc[:,NORIENT]=tmp1.flatten()-tmp1.mean()
184
- wws[:,NORIENT]=tmp2.flatten()-tmp2.mean()
185
- sigma=np.sqrt((wwc[:,NORIENT]**2).mean())
186
-
187
- wwc[:,NORIENT]/=sigma
188
- wws[:,NORIENT]/=sigma
189
- tmp1=r*np.cos(2*theta+np.pi)
190
- tmp2=r*np.sin(2*theta+np.pi)
191
-
192
- wwc[:,NORIENT+1]=tmp1.flatten()-tmp1.mean()
193
- wws[:,NORIENT+1]=tmp2.flatten()-tmp2.mean()
194
- sigma=np.sqrt((wwc[:,NORIENT+1]**2).mean())
195
- wwc[:,NORIENT+1]/=sigma
196
- wws[:,NORIENT+1]/=sigma
197
-
198
-
199
- w_smooth=w_smooth.flatten()
200
-
201
- self.KERNELSZ=KERNELSZ
202
-
203
- self.Idx_Neighbours={}
204
-
205
- if not self.use_2D:
208
+ w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
209
+ tmp1 = np.cos(yy * np.pi) * w_smooth
210
+ tmp2 = np.sin(yy * np.pi) * w_smooth
211
+
212
+ wwc[:, i] = tmp1.flatten() - tmp1.mean()
213
+ wws[:, i] = tmp2.flatten() - tmp2.mean()
214
+ sigma = np.sqrt((wwc[:, i] ** 2).mean())
215
+ wwc[:, i] /= sigma
216
+ wws[:, i] /= sigma
217
+
218
+ if DODIV and i == 0:
219
+ r = xx**2 + yy**2
220
+ theta = np.arctan2(yy, xx)
221
+ theta[KERNELSZ // 2, KERNELSZ // 2] = 0.0
222
+ tmp1 = r * np.cos(2 * theta) * w_smooth
223
+ tmp2 = r * np.sin(2 * theta) * w_smooth
224
+
225
+ wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
226
+ wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
227
+ sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
228
+
229
+ wwc[:, NORIENT] /= sigma
230
+ wws[:, NORIENT] /= sigma
231
+ tmp1 = r * np.cos(2 * theta + np.pi)
232
+ tmp2 = r * np.sin(2 * theta + np.pi)
233
+
234
+ wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
235
+ wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
236
+ sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
237
+ wwc[:, NORIENT + 1] /= sigma
238
+ wws[:, NORIENT + 1] /= sigma
239
+
240
+ w_smooth = w_smooth.flatten()
241
+ if self.use_1D:
242
+ KERNELSZ = 5
243
+
244
+ self.KERNELSZ = KERNELSZ
245
+
246
+ self.Idx_Neighbours = {}
247
+
248
+ if not self.use_2D and not self.use_1D:
206
249
  self.w_smooth = {}
207
250
  for i in range(nstep_max):
208
- lout=(2**i)
209
- self.ww_Real[lout]=None
251
+ lout = 2**i
252
+ self.ww_Real[lout] = None
253
+
254
+ for i in range(1, 6):
255
+ lout = 2**i
256
+ if not self.silent:
257
+ print("Init Wave ", lout)
210
258
 
211
- for i in range(1,6):
212
- lout=(2**i)
213
- print('Init Wave ',lout)
214
-
215
259
  if self.InitWave is None:
216
- wr,wi,ws,widx=self.init_index(lout)
260
+ wr, wi, ws, widx = self.init_index(lout)
217
261
  else:
218
- wr,wi,ws,widx=self.InitWave(self,lout)
219
-
220
- self.Idx_Neighbours[lout]=1 #self.backend.constant(widx)
221
- self.ww_Real[lout]=wr
222
- self.ww_Imag[lout]=wi
223
- self.w_smooth[lout]=ws
224
- else:
225
- self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
226
- self.ww_RealT={}
227
- self.ww_ImagT={}
228
- self.ww_SmoothT={}
262
+ wr, wi, ws, widx = self.InitWave(self, lout)
263
+
264
+ self.Idx_Neighbours[lout] = 1 # self.backend.constant(widx)
265
+ self.ww_Real[lout] = wr
266
+ self.ww_Imag[lout] = wi
267
+ self.w_smooth[lout] = ws
268
+ elif self.use_1D:
269
+ self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
270
+ self.ww_RealT = {}
271
+ self.ww_ImagT = {}
272
+ self.ww_SmoothT = {}
273
+ if KERNELSZ == 5:
274
+ xx = np.arange(5) - 2
275
+ w = np.exp(-0.25 * (xx) ** 2)
276
+ c = w * np.cos((xx) * np.pi / 2)
277
+ s = w * np.sin((xx) * np.pi / 2)
278
+
279
+ w = w / np.sum(w)
280
+ c = c - np.mean(c)
281
+ s = s - np.mean(s)
282
+ r = np.sum(np.sqrt(c * c + s * s))
283
+ c = c / r
284
+ s = s / r
285
+ self.ww_RealT[1] = self.backend.constant(
286
+ np.array(c).reshape(xx.shape[0], 1, 1)
287
+ )
288
+ self.ww_ImagT[1] = self.backend.constant(
289
+ np.array(s).reshape(xx.shape[0], 1, 1)
290
+ )
291
+ self.ww_SmoothT[1] = self.backend.constant(
292
+ np.array(w).reshape(xx.shape[0], 1, 1)
293
+ )
229
294
 
230
- self.ww_SmoothT[1] = self.backend.constant(self.w_smooth.reshape(KERNELSZ,KERNELSZ,1,1))
231
- www=np.zeros([KERNELSZ,KERNELSZ,NORIENT,NORIENT],dtype=self.all_type)
295
+ else:
296
+ self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
297
+ self.ww_RealT = {}
298
+ self.ww_ImagT = {}
299
+ self.ww_SmoothT = {}
300
+
301
+ self.ww_SmoothT[1] = self.backend.constant(
302
+ self.w_smooth.reshape(KERNELSZ, KERNELSZ, 1, 1)
303
+ )
304
+ www = np.zeros([KERNELSZ, KERNELSZ, NORIENT, NORIENT], dtype=self.all_type)
232
305
  for k in range(NORIENT):
233
- www[:,:,k,k]=self.w_smooth.reshape(KERNELSZ,KERNELSZ)
234
- self.ww_SmoothT[NORIENT] = self.backend.constant(www.reshape(KERNELSZ,KERNELSZ,NORIENT,NORIENT))
235
- self.ww_RealT[1]=self.backend.constant(self.backend.bk_reshape(wwc.astype(self.all_type),[KERNELSZ,KERNELSZ,1,NORIENT]))
236
- self.ww_ImagT[1]=self.backend.constant(self.backend.bk_reshape(wws.astype(self.all_type),[KERNELSZ,KERNELSZ,1,NORIENT]))
306
+ www[:, :, k, k] = self.w_smooth.reshape(KERNELSZ, KERNELSZ)
307
+ self.ww_SmoothT[NORIENT] = self.backend.constant(
308
+ www.reshape(KERNELSZ, KERNELSZ, NORIENT, NORIENT)
309
+ )
310
+ self.ww_RealT[1] = self.backend.constant(
311
+ self.backend.bk_reshape(
312
+ wwc.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
313
+ )
314
+ )
315
+ self.ww_ImagT[1] = self.backend.constant(
316
+ self.backend.bk_reshape(
317
+ wws.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
318
+ )
319
+ )
320
+
237
321
  def doorientw(x):
238
- y=np.zeros([KERNELSZ,KERNELSZ,NORIENT,NORIENT*NORIENT],dtype=self.all_type)
322
+ y = np.zeros(
323
+ [KERNELSZ, KERNELSZ, NORIENT, NORIENT * NORIENT],
324
+ dtype=self.all_type,
325
+ )
239
326
  for k in range(NORIENT):
240
- y[:,:,k,k*NORIENT:k*NORIENT+NORIENT]=x.reshape(KERNELSZ,KERNELSZ,NORIENT)
327
+ y[:, :, k, k * NORIENT : k * NORIENT + NORIENT] = x.reshape(
328
+ KERNELSZ, KERNELSZ, NORIENT
329
+ )
241
330
  return y
242
- self.ww_RealT[NORIENT]=self.backend.constant(doorientw(wwc.astype(self.all_type)))
243
- self.ww_ImagT[NORIENT]=self.backend.constant(doorientw(wws.astype(self.all_type)))
244
- self.pix_interp_val={}
245
- self.weight_interp_val={}
246
- self.ring2nest={}
247
- self.nest2R={}
248
- self.nest2R1={}
249
- self.nest2R2={}
250
- self.nest2R3={}
251
- self.nest2R4={}
252
- self.inv_nest2R={}
253
- self.remove_border={}
254
-
255
- self.ampnorm={}
256
-
331
+
332
+ self.ww_RealT[NORIENT] = self.backend.constant(
333
+ doorientw(wwc.astype(self.all_type))
334
+ )
335
+ self.ww_ImagT[NORIENT] = self.backend.constant(
336
+ doorientw(wws.astype(self.all_type))
337
+ )
338
+ self.pix_interp_val = {}
339
+ self.weight_interp_val = {}
340
+ self.ring2nest = {}
341
+ self.nest2R = {}
342
+ self.nest2R1 = {}
343
+ self.nest2R2 = {}
344
+ self.nest2R3 = {}
345
+ self.nest2R4 = {}
346
+ self.inv_nest2R = {}
347
+ self.remove_border = {}
348
+
349
+ self.ampnorm = {}
350
+
257
351
  for i in range(nstep_max):
258
- lout=(2**i)
259
- self.pix_interp_val[lout]={}
260
- self.weight_interp_val[lout]={}
352
+ lout = 2**i
353
+ self.pix_interp_val[lout] = {}
354
+ self.weight_interp_val[lout] = {}
261
355
  for j in range(nstep_max):
262
- lout2=(2**j)
263
- self.pix_interp_val[lout][lout2]=None
264
- self.weight_interp_val[lout][lout2]=None
265
- self.ring2nest[lout]=None
266
- self.Idx_Neighbours[lout]=None
267
- self.nest2R[lout]=None
268
- self.nest2R1[lout]=None
269
- self.nest2R2[lout]=None
270
- self.nest2R3[lout]=None
271
- self.nest2R4[lout]=None
272
- self.inv_nest2R[lout]=None
273
- self.remove_border[lout]=None
274
-
275
- self.loss={}
356
+ lout2 = 2**j
357
+ self.pix_interp_val[lout][lout2] = None
358
+ self.weight_interp_val[lout][lout2] = None
359
+ self.ring2nest[lout] = None
360
+ self.Idx_Neighbours[lout] = None
361
+ self.nest2R[lout] = None
362
+ self.nest2R1[lout] = None
363
+ self.nest2R2[lout] = None
364
+ self.nest2R3[lout] = None
365
+ self.nest2R4[lout] = None
366
+ self.inv_nest2R[lout] = None
367
+ self.remove_border[lout] = None
368
+ self.ww_CNN_Transpose[lout] = None
369
+ self.ww_CNN[lout] = None
370
+ self.X_CNN[lout] = None
371
+ self.Y_CNN[lout] = None
372
+ self.Z_CNN[lout] = None
373
+
374
+ self.loss = {}
276
375
 
277
376
  def get_type(self):
278
377
  return self.all_type
279
378
 
280
379
  def get_mpi_type(self):
281
380
  return self.MPI_ALL_TYPE
282
-
381
+
283
382
  # ---------------------------------------------−---------
284
383
  # -- COMPUTE 3X3 INDEX FOR HEALPIX WORK --
285
384
  # ---------------------------------------------−---------
286
- def conv_to_FoCUS(self,x,axis=0):
287
- if self.use_2D and isinstance(x,np.ndarray):
288
- return(self.to_R(x,axis,chans=self.chans))
385
+ def conv_to_FoCUS(self, x, axis=0):
386
+ if self.use_2D and isinstance(x, np.ndarray):
387
+ return self.to_R(x, axis, chans=self.chans)
289
388
  return x
290
389
 
291
- def diffang(self,a,b):
292
- return np.arctan2(np.sin(a)-np.sin(b),np.cos(a)-np.cos(b))
293
-
294
- def corr_idx_wXX(self,x,y):
295
- idx=np.where(x==-1)[0]
296
- res=x
297
- res[idx]=y[idx]
298
- return(res)
299
-
390
+ def diffang(self, a, b):
391
+ return np.arctan2(np.sin(a) - np.sin(b), np.cos(a) - np.cos(b))
392
+
393
+ def corr_idx_wXX(self, x, y):
394
+ idx = np.where(x == -1)[0]
395
+ res = x
396
+ res[idx] = y[idx]
397
+ return res
398
+
399
+ # ---------------------------------------------−---------
400
+ # make the CNN working : index reporjection of the kernel on healpix
401
+
402
+ def calc_indices_convol(self, nside, kernel, rotation=None):
403
+ to, po = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
404
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside * nside), nest=True)
405
+
406
+ idx = np.argsort((x - 1.0) ** 2 + y**2 + z**2)[0:kernel]
407
+ x0, y0, z0 = hp.pix2vec(nside, idx[0], nest=True)
408
+ t0, p0 = hp.pix2ang(nside, idx[0], nest=True)
409
+
410
+ idx = np.argsort((x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2)[0:kernel]
411
+ im = np.ones([12 * nside**2]) * -1
412
+ im[idx] = np.arange(len(idx))
413
+
414
+ xc, yc, zc = hp.pix2vec(nside, idx, nest=True)
415
+
416
+ xc -= x0
417
+ yc -= y0
418
+ zc -= z0
419
+
420
+ vec = np.concatenate(
421
+ [np.expand_dims(x, -1), np.expand_dims(y, -1), np.expand_dims(z, -1)], 1
422
+ )
423
+
424
+ indices = np.zeros([12 * nside**2 * 250, 2], dtype="int")
425
+ weights = np.zeros([12 * nside**2 * 250])
426
+ nn = 0
427
+ for k in range(12 * nside * nside):
428
+ if k % (nside * nside) == nside * nside - 1:
429
+ print(
430
+ "Nside=%d KenelSZ=%d %.2f%%"
431
+ % (nside, kernel, k / (12 * nside**2) * 100)
432
+ )
433
+ if nside < 4:
434
+ idx2 = np.arange(12 * nside**2)
435
+ else:
436
+ idx2 = hp.query_disc(
437
+ nside, vec[k], np.pi / nside, inclusive=True, nest=True
438
+ )
439
+ t2, p2 = hp.pix2ang(nside, idx2, nest=True)
440
+ if rotation is None:
441
+ rot = [po[k] / np.pi * 180.0, (t0 - to[k]) / np.pi * 180.0]
442
+ else:
443
+ rot = [po[k] / np.pi * 180.0, (t0 - to[k]) / np.pi * 180.0, rotation[k]]
444
+
445
+ r = hp.Rotator(rot=rot)
446
+ t2, p2 = r(t2, p2)
447
+
448
+ ii, ww = hp.get_interp_weights(nside, t2, p2, nest=True)
449
+
450
+ ii = im[ii]
451
+
452
+ for l_rotation in range(4):
453
+ iii = np.where(ii[l_rotation] != -1)[0]
454
+ if len(iii) > 0:
455
+ indices[nn : nn + len(iii), 1] = idx2[iii]
456
+ indices[nn : nn + len(iii), 0] = k * kernel + ii[l_rotation, iii]
457
+ weights[nn : nn + len(iii)] = ww[l_rotation, iii]
458
+ nn += len(iii)
459
+
460
+ indices = indices[0:nn]
461
+ weights = weights[0:nn]
462
+ if k % (nside * nside) == nside * nside - 1:
463
+ print(
464
+ "Nside=%d KenelSZ=%d Total Number of value=%d Ratio of the matrix %.2g%%"
465
+ % (
466
+ nside,
467
+ kernel,
468
+ nn,
469
+ 100 * nn / (kernel * 12 * nside**2 * 12 * nside**2),
470
+ )
471
+ )
472
+ return indices, weights, xc, yc, zc
473
+
474
+ # ---------------------------------------------−---------
475
+ def calc_orientation(self, im): # im is [Ndata,12*Nside**2]
476
+ nside = int(np.sqrt(im.shape[1] // 12))
477
+ l_kernel = self.KERNELSZ * self.KERNELSZ
478
+ norient = 32
479
+ w = np.zeros([l_kernel, 1, 2 * norient])
480
+ ca = np.cos(np.arange(norient) / norient * np.pi)
481
+ sa = np.sin(np.arange(norient) / norient * np.pi)
482
+ stat = np.zeros([12 * nside**2, norient])
483
+
484
+ if self.ww_CNN[nside] is None:
485
+ self.init_CNN_index(nside, transpose=False)
486
+
487
+ y = self.Y_CNN[nside]
488
+ z = self.Z_CNN[nside]
489
+
490
+ for k in range(norient):
491
+ w[:, 0, k] = np.exp(-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)) * np.cos(
492
+ nside * (y * ca[k] + z * sa[k]) * np.pi / 2
493
+ )
494
+ w[:, 0, k + norient] = np.exp(
495
+ -0.5 * nside**2 * ((y) ** 2 + (z) ** 2)
496
+ ) * np.sin(nside * (y * ca[k] + z * sa[k]) * np.pi / 2)
497
+ w[:, 0, k] = w[:, 0, k] - np.mean(w[:, 0, k])
498
+ w[:, 0, k + norient] = w[:, 0, k] - np.mean(w[:, 0, k + norient])
499
+
500
+ for k in range(im.shape[0]):
501
+ tmp = im[k].reshape(12 * nside**2, 1)
502
+ im2 = self.healpix_layer(tmp, w)
503
+ stat = stat + im2[:, 0:norient] ** 2 + im2[:, norient:] ** 2
504
+
505
+ rotation = (np.argmax(stat, 1)).astype("float") / 32.0 * 180.0
506
+
507
+ indices, weights, x, y, z = self.calc_indices_convol(
508
+ nside, 9, rotation=rotation
509
+ )
510
+
511
+ return indices, weights
512
+
513
+ def init_CNN_index(self, nside, transpose=False):
514
+ l_kernel = int(self.KERNELSZ * self.KERNELSZ)
515
+ try:
516
+ indices = np.load(
517
+ "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
518
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
519
+ )
520
+ weights = np.load(
521
+ "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
522
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
523
+ )
524
+ xc = np.load(
525
+ "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
526
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
527
+ )
528
+ yc = np.load(
529
+ "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
530
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
531
+ )
532
+ zc = np.load(
533
+ "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
534
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
535
+ )
536
+ except:
537
+ indices, weights, xc, yc, zc = self.calc_indices_convol(nside, l_kernel)
538
+ np.save(
539
+ "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
540
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
541
+ indices,
542
+ )
543
+ np.save(
544
+ "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
545
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
546
+ weights,
547
+ )
548
+ np.save(
549
+ "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
550
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
551
+ xc,
552
+ )
553
+ np.save(
554
+ "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
555
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
556
+ yc,
557
+ )
558
+ np.save(
559
+ "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
560
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
561
+ zc,
562
+ )
563
+ if not self.silent:
564
+ print(
565
+ "Write %s/FOSCAT_%s_W%d_%d_%d_CNNV2.npy"
566
+ % (
567
+ self.TEMPLATE_PATH,
568
+ TMPFILE_VERSION,
569
+ l_kernel,
570
+ self.NORIENT,
571
+ nside,
572
+ )
573
+ )
574
+
575
+ self.X_CNN[nside] = xc
576
+ self.Y_CNN[nside] = yc
577
+ self.Z_CNN[nside] = zc
578
+ self.ww_CNN[nside] = self.backend.bk_SparseTensor(
579
+ indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
580
+ )
581
+
582
+ # ---------------------------------------------−---------
583
+ def healpix_layer_coord(self, im, axis=0):
584
+ nside = int(np.sqrt(im.shape[axis] // 12))
585
+ if self.ww_CNN[nside] is None:
586
+ self.init_CNN_index(nside)
587
+ return self.X_CNN[nside], self.Y_CNN[nside], self.Z_CNN[nside]
588
+
589
+ # ---------------------------------------------−---------
590
+ def healpix_layer_transpose(self, im, ww, indices=None, weights=None, axis=0):
591
+ nside = int(np.sqrt(im.shape[axis] // 12))
592
+
593
+ if im.shape[1 + axis] != ww.shape[1]:
594
+ if not self.silent:
595
+ print("Weights channels should be equal to the input image channels")
596
+ return -1
597
+ if axis == 1:
598
+ results = []
599
+
600
+ for k in range(im.shape[0]):
601
+
602
+ tmp = self.healpix_layer(
603
+ im[k], ww, indices=indices, weights=weights, axis=0
604
+ )
605
+ tmp = self.backend.bk_reshape(
606
+ self.up_grade(tmp, 2 * nside), [12 * 4 * nside**2, ww.shape[2]]
607
+ )
608
+
609
+ results.append(tmp)
610
+
611
+ return self.backend.bk_stack(results, axis=0)
612
+ else:
613
+ tmp = self.healpix_layer(
614
+ im, ww, indices=indices, weights=weights, axis=axis
615
+ )
616
+
617
+ return self.up_grade(tmp, 2 * nside)
618
+
619
+ # ---------------------------------------------−---------
620
+ # ---------------------------------------------−---------
621
+ def healpix_layer(self, im, ww, indices=None, weights=None, axis=0):
622
+ nside = int(np.sqrt(im.shape[axis] // 12))
623
+ l_kernel = self.KERNELSZ * self.KERNELSZ
624
+
625
+ if im.shape[1 + axis] != ww.shape[1]:
626
+ if not self.silent:
627
+ print("Weights channels should be equal to the input image channels")
628
+ return -1
629
+
630
+ if indices is None:
631
+ if self.ww_CNN[nside] is None:
632
+ self.init_CNN_index(nside, transpose=False)
633
+ mat = self.ww_CNN[nside]
634
+ else:
635
+ if weights is None:
636
+ print(
637
+ "healpix_layer : If indices is not none weights should be specify"
638
+ )
639
+ return 0
640
+
641
+ mat = self.backend.bk_SparseTensor(
642
+ indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
643
+ )
644
+
645
+ if axis == 1:
646
+ results = []
647
+
648
+ for k in range(im.shape[0]):
649
+
650
+ tmp = self.backend.bk_sparse_dense_matmul(mat, im[k])
651
+
652
+ density = self.backend.bk_reshape(
653
+ tmp, [12 * nside * nside, l_kernel * im.shape[1 + axis]]
654
+ )
655
+
656
+ density = self.backend.bk_matmul(
657
+ density,
658
+ self.backend.bk_reshape(
659
+ ww, [l_kernel * im.shape[1 + axis], ww.shape[2]]
660
+ ),
661
+ )
662
+
663
+ results.append(
664
+ self.backend.bk_reshape(density, [12 * nside**2, ww.shape[2]])
665
+ )
666
+
667
+ return self.backend.bk_stack(results, axis=0)
668
+ else:
669
+ tmp = self.backend.bk_sparse_dense_matmul(mat, im)
670
+
671
+ density = self.backend.bk_reshape(
672
+ tmp, [12 * nside * nside, l_kernel * im.shape[1]]
673
+ )
674
+
675
+ return self.backend.bk_matmul(
676
+ density,
677
+ self.backend.bk_reshape(ww, [l_kernel * im.shape[1], ww.shape[2]]),
678
+ )
679
+
680
+ # ---------------------------------------------−---------
681
+
300
682
  # ---------------------------------------------−---------
301
683
  def get_rank(self):
302
- return(self.rank)
684
+ return self.rank
685
+
303
686
  # ---------------------------------------------−---------
304
687
  def get_size(self):
305
- return(self.size)
306
-
688
+ return self.size
689
+
307
690
  # ---------------------------------------------−---------
308
691
  def barrier(self):
309
692
  if self.isMPI:
310
693
  self.comm.Barrier()
311
-
694
+
312
695
  # ---------------------------------------------−---------
313
- def toring(self,image,axis=0):
314
- lout=int(np.sqrt(image.shape[axis]//12))
315
-
696
+ def toring(self, image, axis=0):
697
+ lout = int(np.sqrt(image.shape[axis] // 12))
698
+
316
699
  if self.ring2nest[lout] is None:
317
- self.ring2nest[lout]=hp.ring2nest(lout,np.arange(12*lout**2))
318
-
700
+ self.ring2nest[lout] = hp.ring2nest(lout, np.arange(12 * lout**2))
701
+
319
702
  return image.numpy()[self.ring2nest[lout]]
320
703
 
321
- #--------------------------------------------------------
322
- def ud_grade(self,im,j,axis=0):
323
- rim=im
704
+ # --------------------------------------------------------
705
+ def ud_grade(self, im, j, axis=0):
706
+ rim = im
324
707
  for k in range(j):
325
- rim=self.smooth(rim,axis=axis)
326
- rim=self.ud_grade_2(rim,axis=axis)
708
+ rim = self.smooth(rim, axis=axis)
709
+ rim = self.ud_grade_2(rim, axis=axis)
327
710
  return rim
328
-
329
- #--------------------------------------------------------
330
- def ud_grade_2(self,im,axis=0):
331
-
711
+
712
+ # --------------------------------------------------------
713
+ def ud_grade_2(self, im, axis=0):
714
+
332
715
  if self.use_2D:
333
- ishape=list(im.shape)
334
- if len(ishape)<axis+2:
335
- print('Use of 2D scat with data that has less than 2D')
336
- exit(0)
337
-
338
- npix=im.shape[axis]
339
- npiy=im.shape[axis+1]
340
- odata=1
341
- if len(ishape)>axis+2:
342
- for k in range(axis+2,len(ishape)):
343
- odata=odata*ishape[k]
344
-
345
- ndata=1
716
+ ishape = list(im.shape)
717
+ if len(ishape) < axis + 2:
718
+ if not self.silent:
719
+ print("Use of 2D scat with data that has less than 2D")
720
+ return None
721
+
722
+ npix = im.shape[axis]
723
+ npiy = im.shape[axis + 1]
724
+ odata = 1
725
+ if len(ishape) > axis + 2:
726
+ for k in range(axis + 2, len(ishape)):
727
+ odata = odata * ishape[k]
728
+
729
+ ndata = 1
346
730
  for k in range(axis):
347
- ndata=ndata*ishape[k]
731
+ ndata = ndata * ishape[k]
348
732
 
349
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
350
- tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),0:2*(npiy//2),:],[ndata,npix//2,2,npiy//2,2,odata])
733
+ tim = self.backend.bk_reshape(
734
+ self.backend.bk_cast(im), [ndata, npix, npiy, odata]
735
+ )
736
+ tim = self.backend.bk_reshape(
737
+ tim[:, 0 : 2 * (npix // 2), 0 : 2 * (npiy // 2), :],
738
+ [ndata, npix // 2, 2, npiy // 2, 2, odata],
739
+ )
351
740
 
352
- res=self.backend.bk_reduce_mean(self.backend.bk_reduce_mean(tim,4),2)
353
-
354
- if axis==0:
355
- if len(ishape)==2:
356
- return self.backend.bk_reshape(res,[npix//2,npiy//2])
741
+ res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
742
+
743
+ if axis == 0:
744
+ if len(ishape) == 2:
745
+ return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
357
746
  else:
358
- return self.backend.bk_reshape(res,[npix//2,npiy//2]+ishape[axis+2:])
747
+ return self.backend.bk_reshape(
748
+ res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
749
+ )
359
750
  else:
360
- if len(ishape)==axis+2:
361
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2])
751
+ if len(ishape) == axis + 2:
752
+ return self.backend.bk_reshape(
753
+ res, ishape[0:axis] + [npix // 2, npiy // 2]
754
+ )
755
+ else:
756
+ return self.backend.bk_reshape(
757
+ res,
758
+ ishape[0:axis] + [npix // 2, npiy // 2] + ishape[axis + 2 :],
759
+ )
760
+
761
+ return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
762
+ elif self.use_1D:
763
+ ishape = list(im.shape)
764
+ if len(ishape) < axis + 1:
765
+ if not self.silent:
766
+ print("Use of 1D scat with data that has less than 1D")
767
+ return None
768
+
769
+ npix = im.shape[axis]
770
+ odata = 1
771
+ if len(ishape) > axis + 1:
772
+ for k in range(axis + 1, len(ishape)):
773
+ odata = odata * ishape[k]
774
+
775
+ ndata = 1
776
+ for k in range(axis):
777
+ ndata = ndata * ishape[k]
778
+
779
+ tim = self.backend.bk_reshape(
780
+ self.backend.bk_cast(im), [ndata, npix, odata]
781
+ )
782
+ tim = self.backend.bk_reshape(
783
+ tim[:, 0 : 2 * (npix // 2), :], [ndata, npix // 2, 2, odata]
784
+ )
785
+
786
+ res = self.backend.bk_reduce_mean(tim, 2)
787
+
788
+ if axis == 0:
789
+ if len(ishape) == 1:
790
+ return self.backend.bk_reshape(res, [npix // 2])
362
791
  else:
363
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2]+ishape[axis+2:])
364
-
365
- return self.backend.bk_reshape(res,[npix//2,npiy//2])
366
-
792
+ return self.backend.bk_reshape(
793
+ res, [npix // 2] + ishape[axis + 1 :]
794
+ )
795
+ else:
796
+ if len(ishape) == axis + 1:
797
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2])
798
+ else:
799
+ return self.backend.bk_reshape(
800
+ res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
801
+ )
802
+
803
+ return self.backend.bk_reshape(res, [npix // 2])
804
+
367
805
  else:
368
- shape=list(im.shape)
369
-
370
- lout=int(np.sqrt(shape[axis]//12))
371
- if im.__class__==np.zeros([0]).__class__:
372
- oshape=np.zeros([len(shape)+1],dtype='int')
373
- if axis>0:
374
- oshape[0:axis]=shape[0:axis]
375
- oshape[axis]=12*lout*lout//4
376
- oshape[axis+1]=4
377
- if len(shape)>axis:
378
- oshape[axis+2:]=shape[axis+1:]
806
+ shape = list(im.shape)
807
+
808
+ lout = int(np.sqrt(shape[axis] // 12))
809
+ if im.__class__ == np.zeros([0]).__class__:
810
+ oshape = np.zeros([len(shape) + 1], dtype="int")
811
+ if axis > 0:
812
+ oshape[0:axis] = shape[0:axis]
813
+ oshape[axis] = 12 * lout * lout // 4
814
+ oshape[axis + 1] = 4
815
+ if len(shape) > axis:
816
+ oshape[axis + 2 :] = shape[axis + 1 :]
379
817
  else:
380
- if axis>0:
381
- oshape=shape[0:axis]+[12*lout*lout//4,4]
818
+ if axis > 0:
819
+ oshape = shape[0:axis] + [12 * lout * lout // 4, 4]
382
820
  else:
383
- oshape=[12*lout*lout//4,4]
384
- if len(shape)>axis:
385
- oshape=oshape+shape[axis+1:]
386
-
387
- return(self.backend.bk_reduce_mean(self.backend.bk_reshape(im,oshape),axis=axis+1))
388
-
389
- #--------------------------------------------------------
390
- def up_grade(self,im,nout,axis=0,nouty=None):
391
-
821
+ oshape = [12 * lout * lout // 4, 4]
822
+ if len(shape) > axis:
823
+ oshape = oshape + shape[axis + 1 :]
824
+
825
+ return self.backend.bk_reduce_mean(
826
+ self.backend.bk_reshape(im, oshape), axis=axis + 1
827
+ )
828
+
829
+ # --------------------------------------------------------
830
+ def up_grade(self, im, nout, axis=0, nouty=None):
831
+
392
832
  if self.use_2D:
393
- ishape=list(im.shape)
394
- if len(ishape)<axis+2:
395
- print('Use of 2D scat with data that has less than 2D')
396
- exit(0)
397
-
833
+ ishape = list(im.shape)
834
+ if len(ishape) < axis + 2:
835
+ if not self.silent:
836
+ print("Use of 2D scat with data that has less than 2D")
837
+ return None
838
+
398
839
  if nouty is None:
399
- nouty=nout
400
- npix=im.shape[axis]
401
- npiy=im.shape[axis+1]
402
- odata=1
403
- if len(ishape)>axis+2:
404
- for k in range(axis+2,len(ishape)):
405
- odata=odata*ishape[k]
406
-
407
- ndata=1
840
+ nouty = nout
841
+
842
+ if ishape[axis] == nout and ishape[axis + 1] == nouty:
843
+ return im
844
+
845
+ npix = im.shape[axis]
846
+ npiy = im.shape[axis + 1]
847
+ odata = 1
848
+ if len(ishape) > axis + 2:
849
+ for k in range(axis + 2, len(ishape)):
850
+ odata = odata * ishape[k]
851
+
852
+ ndata = 1
408
853
  for k in range(axis):
409
- ndata=ndata*ishape[k]
854
+ ndata = ndata * ishape[k]
855
+
856
+ tim = self.backend.bk_reshape(
857
+ self.backend.bk_cast(im), [ndata, npix, npiy, odata]
858
+ )
410
859
 
411
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
860
+ res = self.backend.bk_resize_image(tim, [nout, nouty])
412
861
 
413
- res=self.backend.bk_resize_image(tim,[nout,nouty])
414
-
415
- if axis==0:
416
- if len(ishape)==2:
417
- return self.backend.bk_reshape(res,[nout,nouty])
862
+ if axis == 0:
863
+ if len(ishape) == 2:
864
+ return self.backend.bk_reshape(res, [nout, nouty])
418
865
  else:
419
- return self.backend.bk_reshape(res,[nout,nouty]+ishape[axis+2:])
866
+ return self.backend.bk_reshape(
867
+ res, [nout, nouty] + ishape[axis + 2 :]
868
+ )
420
869
  else:
421
- if len(ishape)==axis+2:
422
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty])
870
+ if len(ishape) == axis + 2:
871
+ return self.backend.bk_reshape(res, ishape[0:axis] + [nout, nouty])
423
872
  else:
424
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty]+ishape[axis+2:])
425
-
426
- return self.backend.bk_reshape(res,[nout,nouty])
427
-
873
+ return self.backend.bk_reshape(
874
+ res, ishape[0:axis] + [nout, nouty] + ishape[axis + 2 :]
875
+ )
876
+
877
+ return self.backend.bk_reshape(res, [nout, nouty])
878
+
879
+ elif self.use_1D:
880
+ ishape = list(im.shape)
881
+ if len(ishape) < axis + 1:
882
+ if not self.silent:
883
+ print("Use of 1D scat with data that has less than 1D")
884
+ return None
885
+
886
+ if ishape[axis] == nout:
887
+ return im
888
+
889
+ npix = im.shape[axis]
890
+ odata = 1
891
+ if len(ishape) > axis + 1:
892
+ for k in range(axis + 1, len(ishape)):
893
+ odata = odata * ishape[k]
894
+
895
+ ndata = 1
896
+ for k in range(axis):
897
+ ndata = ndata * ishape[k]
898
+
899
+ tim = self.backend.bk_reshape(
900
+ self.backend.bk_cast(im), [ndata, npix, odata]
901
+ )
902
+
903
+ while tim.shape[1] != nout:
904
+ res2 = self.backend.bk_expand_dims(
905
+ self.backend.bk_concat(
906
+ [(tim[:, 1:, :] + 3 * tim[:, :-1, :]) / 4, tim[:, -1:, :]], 1
907
+ ),
908
+ -2,
909
+ )
910
+ res1 = self.backend.bk_expand_dims(
911
+ self.backend.bk_concat(
912
+ [tim[:, 0:1, :], (tim[:, 1:, :] * 3 + tim[:, :-1, :]) / 4], 1
913
+ ),
914
+ -2,
915
+ )
916
+ tim = self.backend.bk_reshape(
917
+ self.backend.bk_concat([res1, res2], -2),
918
+ [ndata, tim.shape[1] * 2, odata],
919
+ )
920
+
921
+ if axis == 0:
922
+ if len(ishape) == 1:
923
+ return self.backend.bk_reshape(tim, [nout])
924
+ else:
925
+ return self.backend.bk_reshape(tim, [nout] + ishape[axis + 1 :])
926
+ else:
927
+ if len(ishape) == axis + 1:
928
+ return self.backend.bk_reshape(tim, ishape[0:axis] + [nout])
929
+ else:
930
+ return self.backend.bk_reshape(
931
+ tim, ishape[0:axis] + [nout] + ishape[axis + 1 :]
932
+ )
933
+
934
+ return self.backend.bk_reshape(tim, [nout])
935
+
428
936
  else:
429
937
 
430
- lout=int(np.sqrt(im.shape[axis]//12))
431
-
938
+ lout = int(np.sqrt(im.shape[axis] // 12))
939
+
432
940
  if self.pix_interp_val[lout][nout] is None:
433
- print('compute lout nout',lout,nout)
434
- th,ph=hp.pix2ang(nout,np.arange(12*nout**2,dtype='int'),nest=True)
435
- p, w = hp.get_interp_weights(lout,th,ph,nest=True)
941
+ if not self.silent:
942
+ print("compute lout nout", lout, nout)
943
+ th, ph = hp.pix2ang(
944
+ nout, np.arange(12 * nout**2, dtype="int"), nest=True
945
+ )
946
+ p, w = hp.get_interp_weights(lout, th, ph, nest=True)
436
947
  del th
437
948
  del ph
438
-
439
- indice=np.zeros([12*nout*nout*4,2],dtype='int')
440
- p=p.T
441
- w=w.T
442
- t=np.argsort(p,1).flatten() # to make oder indices for sparsematrix computation
443
- t=(t+np.repeat(np.arange(12*nout*nout)*4,4))
444
- p=p.flatten()[t]
445
- w=w.flatten()[t]
446
- indice[:,0]=np.repeat(np.arange(12*nout**2),4)
447
- indice[:,1]=p
448
-
449
- self.pix_interp_val[lout][nout]=1
450
- self.weight_interp_val[lout][nout] = self.backend.bk_SparseTensor(self.backend.constant(indice), \
451
- self.backend.constant(self.backend.bk_cast(w.flatten())), \
452
- dense_shape=[12*nout**2,12*lout**2])
453
-
454
- if lout==nout:
455
- imout=im
949
+
950
+ indice = np.zeros([12 * nout * nout * 4, 2], dtype="int")
951
+ p = p.T
952
+ w = w.T
953
+ t = np.argsort(
954
+ p, 1
955
+ ).flatten() # to make oder indices for sparsematrix computation
956
+ t = t + np.repeat(np.arange(12 * nout * nout) * 4, 4)
957
+ p = p.flatten()[t]
958
+ w = w.flatten()[t]
959
+ indice[:, 0] = np.repeat(np.arange(12 * nout**2), 4)
960
+ indice[:, 1] = p
961
+
962
+ self.pix_interp_val[lout][nout] = 1
963
+ self.weight_interp_val[lout][nout] = self.backend.bk_SparseTensor(
964
+ self.backend.constant(indice),
965
+ self.backend.constant(self.backend.bk_cast(w.flatten())),
966
+ dense_shape=[12 * nout**2, 12 * lout**2],
967
+ )
968
+
969
+ if lout == nout:
970
+ imout = im
456
971
  else:
457
972
 
458
- ishape=list(im.shape)
459
- odata=1
460
- for k in range(axis+1,len(ishape)):
461
- odata=odata*ishape[k]
462
-
463
- ndata=1
973
+ ishape = list(im.shape)
974
+ odata = 1
975
+ for k in range(axis + 1, len(ishape)):
976
+ odata = odata * ishape[k]
977
+
978
+ ndata = 1
464
979
  for k in range(axis):
465
- ndata=ndata*ishape[k]
466
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,12*lout**2,odata])
467
- if tim.dtype==self.all_cbk_type:
468
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
469
- ,self.backend.bk_real(tim[0])),[1,12*nout**2,odata])
470
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
471
- ,self.backend.bk_imag(tim[0])),[1,12*nout**2,odata])
472
- imout=self.backend.bk_complex(rr,ii)
980
+ ndata = ndata * ishape[k]
981
+ tim = self.backend.bk_reshape(
982
+ self.backend.bk_cast(im), [ndata, 12 * lout**2, odata]
983
+ )
984
+ if tim.dtype == self.all_cbk_type:
985
+ rr = self.backend.bk_reshape(
986
+ self.backend.bk_sparse_dense_matmul(
987
+ self.weight_interp_val[lout][nout],
988
+ self.backend.bk_real(tim[0]),
989
+ ),
990
+ [1, 12 * nout**2, odata],
991
+ )
992
+ ii = self.backend.bk_reshape(
993
+ self.backend.bk_sparse_dense_matmul(
994
+ self.weight_interp_val[lout][nout],
995
+ self.backend.bk_imag(tim[0]),
996
+ ),
997
+ [1, 12 * nout**2, odata],
998
+ )
999
+ imout = self.backend.bk_complex(rr, ii)
473
1000
  else:
474
- imout=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
475
- ,tim[0]),[1,12*nout**2,odata])
476
-
477
- for k in range(1,ndata):
478
- if tim.dtype==self.all_cbk_type:
479
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
480
- ,self.backend.bk_real(tim[k])),[1,12*nout**2,odata])
481
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
482
- ,self.backend.bk_imag(tim[k])),[1,12*nout**2,odata])
483
- imout=self.backend.bk_concat([imout,self.backend.bk_complex(rr,ii)],0)
1001
+ imout = self.backend.bk_reshape(
1002
+ self.backend.bk_sparse_dense_matmul(
1003
+ self.weight_interp_val[lout][nout], tim[0]
1004
+ ),
1005
+ [1, 12 * nout**2, odata],
1006
+ )
1007
+
1008
+ for k in range(1, ndata):
1009
+ if tim.dtype == self.all_cbk_type:
1010
+ rr = self.backend.bk_reshape(
1011
+ self.backend.bk_sparse_dense_matmul(
1012
+ self.weight_interp_val[lout][nout],
1013
+ self.backend.bk_real(tim[k]),
1014
+ ),
1015
+ [1, 12 * nout**2, odata],
1016
+ )
1017
+ ii = self.backend.bk_reshape(
1018
+ self.backend.bk_sparse_dense_matmul(
1019
+ self.weight_interp_val[lout][nout],
1020
+ self.backend.bk_imag(tim[k]),
1021
+ ),
1022
+ [1, 12 * nout**2, odata],
1023
+ )
1024
+ imout = self.backend.bk_concat(
1025
+ [imout, self.backend.bk_complex(rr, ii)], 0
1026
+ )
484
1027
  else:
485
- imout=self.backend.bk_concat([imout,self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
486
- ,tim[k]),[1,12*nout**2,odata])],0)
487
-
488
- if axis==0:
489
- if len(ishape)==1:
490
- return self.backend.bk_reshape(imout,[12*nout**2])
1028
+ imout = self.backend.bk_concat(
1029
+ [
1030
+ imout,
1031
+ self.backend.bk_reshape(
1032
+ self.backend.bk_sparse_dense_matmul(
1033
+ self.weight_interp_val[lout][nout], tim[k]
1034
+ ),
1035
+ [1, 12 * nout**2, odata],
1036
+ ),
1037
+ ],
1038
+ 0,
1039
+ )
1040
+
1041
+ if axis == 0:
1042
+ if len(ishape) == 1:
1043
+ return self.backend.bk_reshape(imout, [12 * nout**2])
491
1044
  else:
492
- return self.backend.bk_reshape(imout,[12*nout**2]+ishape[axis+1:])
1045
+ return self.backend.bk_reshape(
1046
+ imout, [12 * nout**2] + ishape[axis + 1 :]
1047
+ )
493
1048
  else:
494
- if len(ishape)==axis+1:
495
- return self.backend.bk_reshape(imout,ishape[0:axis]+[12*nout**2])
1049
+ if len(ishape) == axis + 1:
1050
+ return self.backend.bk_reshape(
1051
+ imout, ishape[0:axis] + [12 * nout**2]
1052
+ )
496
1053
  else:
497
- return self.backend.bk_reshape(imout,ishape[0:axis]+[12*nout**2]+ishape[axis+1:])
498
- return(imout)
499
-
500
- #--------------------------------------------------------
501
- def fill_1d(self,i_arr,nullval=0):
502
- arr=i_arr.copy()
1054
+ return self.backend.bk_reshape(
1055
+ imout, ishape[0:axis] + [12 * nout**2] + ishape[axis + 1 :]
1056
+ )
1057
+ return imout
1058
+
1059
+ # --------------------------------------------------------
1060
+ def fill_1d(self, i_arr, nullval=0):
1061
+ arr = i_arr.copy()
503
1062
  # Indices des éléments non nuls
504
- non_zero_indices = np.where(arr!=nullval)[0]
505
-
1063
+ non_zero_indices = np.where(arr != nullval)[0]
1064
+
506
1065
  # Indices de tous les éléments
507
1066
  all_indices = np.arange(len(arr))
508
-
1067
+
509
1068
  # Interpoler linéairement en utilisant np.interp
510
1069
  # np.interp(x, xp, fp) : x sont les indices pour lesquels on veut obtenir des valeurs
511
1070
  # xp sont les indices des données existantes, fp sont les valeurs des données existantes
512
- interpolated_values = np.interp(all_indices, non_zero_indices, arr[non_zero_indices])
513
-
1071
+ interpolated_values = np.interp(
1072
+ all_indices, non_zero_indices, arr[non_zero_indices]
1073
+ )
1074
+
514
1075
  # Mise à jour du tableau original
515
- arr[arr==nullval] = interpolated_values[arr==nullval]
516
-
1076
+ arr[arr == nullval] = interpolated_values[arr == nullval]
1077
+
517
1078
  return arr
518
1079
 
519
- def fill_2d(self,i_arr,nullval=0):
520
- arr=i_arr.copy()
1080
+ def fill_2d(self, i_arr, nullval=0):
1081
+ arr = i_arr.copy()
521
1082
  # Créer une grille de coordonnées correspondant aux indices du tableau
522
1083
  x, y = np.indices(arr.shape)
523
-
1084
+
524
1085
  # Extraire les coordonnées des points non nuls ainsi que leurs valeurs
525
1086
  non_zero_points = np.array((x[arr != nullval], y[arr != nullval])).T
526
1087
  non_zero_values = arr[arr != nullval]
527
-
1088
+
528
1089
  # Extraire les coordonnées des points nuls
529
1090
  zero_points = np.array((x[arr == nullval], y[arr == nullval])).T
530
1091
 
531
1092
  # Interpolation linéaire
532
- interpolated_values = griddata(non_zero_points, non_zero_values, zero_points, method='linear')
1093
+ interpolated_values = griddata(
1094
+ non_zero_points, non_zero_values, zero_points, method="linear"
1095
+ )
533
1096
 
534
1097
  # Remplacer les valeurs nulles par les valeurs interpolées
535
1098
  arr[arr == nullval] = interpolated_values
536
1099
 
537
1100
  return arr
538
-
539
- def fill_healpy(self,i_map,nmax=10,nullval=hp.UNSEEN):
540
- map=1*i_map
1101
+
1102
+ def fill_healpy(self, i_map, nmax=10, nullval=hp.UNSEEN):
1103
+ map = 1 * i_map
541
1104
  # Trouver les pixels nuls
542
1105
  nside = hp.npix2nside(len(map))
543
1106
  null_indices = np.where(map == nullval)[0]
544
-
545
- itt=0
546
- while null_indices.shape[0]>0 and itt<nmax:
1107
+
1108
+ itt = 0
1109
+ while null_indices.shape[0] > 0 and itt < nmax:
547
1110
  # Trouver les coordonnées theta, phi pour les pixels nuls
548
1111
  theta, phi = hp.pix2ang(nside, null_indices)
549
-
1112
+
550
1113
  # Interpoler les valeurs en utilisant les pixels voisins
551
1114
  # La fonction get_interp_val peut être utilisée pour obtenir les valeurs interpolées
552
1115
  # pour des positions données en theta et phi.
553
1116
  i_idx = hp.get_all_neighbours(nside, theta, phi)
554
-
555
- i_w=(map[i_idx]!=nullval)*(i_idx!=-1)
556
- vv=np.sum(i_w,0)
557
- interpolated_values=np.sum(i_w*map[i_idx],0)
1117
+
1118
+ i_w = (map[i_idx] != nullval) * (i_idx != -1)
1119
+ vv = np.sum(i_w, 0)
1120
+ interpolated_values = np.sum(i_w * map[i_idx], 0)
558
1121
 
559
1122
  # Remplacer les valeurs nulles par les valeurs interpolées
560
- map[null_indices[vv>0]] = interpolated_values[vv>0]/vv[vv>0]
1123
+ map[null_indices[vv > 0]] = interpolated_values[vv > 0] / vv[vv > 0]
561
1124
 
562
1125
  null_indices = np.where(map == nullval)[0]
563
- itt+=1
564
-
1126
+ itt += 1
1127
+
565
1128
  return map
566
-
567
- #--------------------------------------------------------
568
- def ud_grade_1d(self,im,nout,axis=0):
569
- npix=im.shape[axis]
570
-
571
- ishape=list(im.shape)
572
- odata=1
573
- for k in range(axis+1,len(ishape)):
574
- odata=odata*ishape[k]
575
-
576
- ndata=1
577
- for k in range(axis):
578
- ndata=ndata*ishape[k]
579
1129
 
580
- nscale=npix//nout
581
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix//nscale,nscale,odata])
1130
+ # --------------------------------------------------------
1131
+ def ud_grade_1d(self, im, nout, axis=0):
1132
+ npix = im.shape[axis]
582
1133
 
583
- res = self.backend.bk_reduce_mean(tim,2)
584
-
585
- if axis==0:
586
- if len(ishape)==1:
587
- return self.backend.bk_reshape(res,[nout])
1134
+ ishape = list(im.shape)
1135
+ odata = 1
1136
+ for k in range(axis + 1, len(ishape)):
1137
+ odata = odata * ishape[k]
1138
+
1139
+ ndata = 1
1140
+ for k in range(axis):
1141
+ ndata = ndata * ishape[k]
1142
+
1143
+ nscale = npix // nout
1144
+ if npix % nscale == 0:
1145
+ tim = self.backend.bk_reshape(
1146
+ self.backend.bk_cast(im), [ndata, npix // nscale, nscale, odata]
1147
+ )
1148
+ else:
1149
+ im = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1150
+ tim = self.backend.bk_reshape(
1151
+ self.backend.bk_cast(im[:, 0 : nscale * (npix // nscale), :]),
1152
+ [ndata, npix // nscale, nscale, odata],
1153
+ )
1154
+ res = self.backend.bk_reduce_mean(tim, 2)
1155
+
1156
+ if axis == 0:
1157
+ if len(ishape) == 1:
1158
+ return self.backend.bk_reshape(res, [nout])
588
1159
  else:
589
- return self.backend.bk_reshape(res,[nout]+ishape[axis+1:])
1160
+ return self.backend.bk_reshape(res, [nout] + ishape[axis + 1 :])
590
1161
  else:
591
- if len(ishape)==axis+1:
592
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout])
1162
+ if len(ishape) == axis + 1:
1163
+ return self.backend.bk_reshape(res, ishape[0:axis] + [nout])
593
1164
  else:
594
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout]+ishape[axis+1:])
595
- return self.backend.bk_reshape(res,[nout])
596
-
597
- #--------------------------------------------------------
598
- def up_grade_2_1d(self,im,axis=0):
599
-
600
- npix=im.shape[axis]
601
-
602
- ishape=list(im.shape)
603
- odata=1
604
- for k in range(axis+1,len(ishape)):
605
- odata=odata*ishape[k]
606
-
607
- ndata=1
1165
+ return self.backend.bk_reshape(
1166
+ res, ishape[0:axis] + [nout] + ishape[axis + 1 :]
1167
+ )
1168
+ return self.backend.bk_reshape(res, [nout])
1169
+
1170
+ # --------------------------------------------------------
1171
+ def up_grade_2_1d(self, im, axis=0):
1172
+
1173
+ npix = im.shape[axis]
1174
+
1175
+ ishape = list(im.shape)
1176
+ odata = 1
1177
+ for k in range(axis + 1, len(ishape)):
1178
+ odata = odata * ishape[k]
1179
+
1180
+ ndata = 1
608
1181
  for k in range(axis):
609
- ndata=ndata*ishape[k]
610
-
611
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
612
-
613
- res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
614
- res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
615
- res = self.backend.bk_concat([res1,res2],-2)
616
-
617
- if axis==0:
618
- if len(ishape)==1:
619
- return self.backend.bk_reshape(res,[npix*2])
1182
+ ndata = ndata * ishape[k]
1183
+
1184
+ tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1185
+
1186
+ res2 = self.backend.bk_expand_dims(
1187
+ self.backend.bk_concat(
1188
+ [(tim[:, 1:, :] + 3 * tim[:, :-1, :]) / 4, tim[:, -1:, :]], 1
1189
+ ),
1190
+ -2,
1191
+ )
1192
+ res1 = self.backend.bk_expand_dims(
1193
+ self.backend.bk_concat(
1194
+ [tim[:, 0:1, :], (tim[:, 1:, :] * 3 + tim[:, :-1, :]) / 4], 1
1195
+ ),
1196
+ -2,
1197
+ )
1198
+ res = self.backend.bk_concat([res1, res2], -2)
1199
+
1200
+ if axis == 0:
1201
+ if len(ishape) == 1:
1202
+ return self.backend.bk_reshape(res, [npix * 2])
620
1203
  else:
621
- return self.backend.bk_reshape(res,[npix*2]+ishape[axis+1:])
1204
+ return self.backend.bk_reshape(res, [npix * 2] + ishape[axis + 1 :])
622
1205
  else:
623
- if len(ishape)==axis+1:
624
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix*2])
1206
+ if len(ishape) == axis + 1:
1207
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix * 2])
625
1208
  else:
626
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix*2]+ishape[axis+1:])
627
- return self.backend.bk_reshape(res,[npix*2])
628
-
629
-
630
- #--------------------------------------------------------
631
- def convol_1d(self,im,axis=0):
632
-
633
- xx=np.arange(5)-2
634
- w=np.exp(-0.17328679514*(xx)**2)
635
- c=np.cos((xx)*np.pi/2)
636
- s=np.sin((xx)*np.pi/2)
637
-
638
- wr=np.array(w*c).reshape(xx.shape[0],1,1)
639
- wi=np.array(w*s).reshape(xx.shape[0],1,1)
640
-
641
- npix=im.shape[axis]
642
-
643
- ishape=list(im.shape)
644
- odata=1
645
- for k in range(axis+1,len(ishape)):
646
- odata=odata*ishape[k]
647
-
648
- ndata=1
1209
+ return self.backend.bk_reshape(
1210
+ res, ishape[0:axis] + [npix * 2] + ishape[axis + 1 :]
1211
+ )
1212
+ return self.backend.bk_reshape(res, [npix * 2])
1213
+
1214
+ # --------------------------------------------------------
1215
+ def convol_1d(self, im, axis=0):
1216
+
1217
+ xx = np.arange(5) - 2
1218
+ w = np.exp(-0.17328679514 * (xx) ** 2)
1219
+ c = np.cos((xx) * np.pi / 2)
1220
+ s = np.sin((xx) * np.pi / 2)
1221
+
1222
+ wr = np.array(w * c).reshape(xx.shape[0], 1, 1)
1223
+ wi = np.array(w * s).reshape(xx.shape[0], 1, 1)
1224
+
1225
+ npix = im.shape[axis]
1226
+
1227
+ ishape = list(im.shape)
1228
+ odata = 1
1229
+ for k in range(axis + 1, len(ishape)):
1230
+ odata = odata * ishape[k]
1231
+
1232
+ ndata = 1
649
1233
  for k in range(axis):
650
- ndata=ndata*ishape[k]
651
-
652
- if odata>1:
653
- wr=np.repeat(wr,odata,2)
654
- wi=np.repeat(wi,odata,2)
655
-
656
- wr=self.backend.bk_cast(self.backend.constant(wr))
657
- wi=self.backend.bk_cast(self.backend.constant(wi))
658
-
659
- tim = self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
660
-
661
- if tim.dtype==self.all_cbk_type:
662
- rr1 = self.backend.bk_conv1d(self.backend.bk_real(tim),wr)
663
- ii1 = self.backend.bk_conv1d(self.backend.bk_real(tim),wi)
664
- rr2 = self.backend.bk_conv1d(self.backend.bk_imag(tim),wr)
665
- ii2 = self.backend.bk_conv1d(self.backend.bk_imag(tim),wi)
666
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
1234
+ ndata = ndata * ishape[k]
1235
+
1236
+ if odata > 1:
1237
+ wr = np.repeat(wr, odata, 2)
1238
+ wi = np.repeat(wi, odata, 2)
1239
+
1240
+ wr = self.backend.bk_cast(self.backend.constant(wr))
1241
+ wi = self.backend.bk_cast(self.backend.constant(wi))
1242
+
1243
+ tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1244
+
1245
+ if tim.dtype == self.all_cbk_type:
1246
+ rr1 = self.backend.bk_conv1d(self.backend.bk_real(tim), wr)
1247
+ ii1 = self.backend.bk_conv1d(self.backend.bk_real(tim), wi)
1248
+ rr2 = self.backend.bk_conv1d(self.backend.bk_imag(tim), wr)
1249
+ ii2 = self.backend.bk_conv1d(self.backend.bk_imag(tim), wi)
1250
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
667
1251
  else:
668
- rr = self.backend.bk_conv1d(tim,wr)
669
- ii = self.backend.bk_conv1d(tim,wi)
670
-
671
- res=self.backend.bk_complex(rr,ii)
672
-
673
- if axis==0:
674
- if len(ishape)==1:
675
- return self.backend.bk_reshape(res,[npix])
1252
+ rr = self.backend.bk_conv1d(tim, wr)
1253
+ ii = self.backend.bk_conv1d(tim, wi)
1254
+
1255
+ res = self.backend.bk_complex(rr, ii)
1256
+
1257
+ if axis == 0:
1258
+ if len(ishape) == 1:
1259
+ return self.backend.bk_reshape(res, [npix])
676
1260
  else:
677
- return self.backend.bk_reshape(res,[npix]+ishape[axis+1:])
1261
+ return self.backend.bk_reshape(res, [npix] + ishape[axis + 1 :])
678
1262
  else:
679
- if len(ishape)==axis+1:
680
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix])
1263
+ if len(ishape) == axis + 1:
1264
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix])
681
1265
  else:
682
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix]+ishape[axis+1:])
683
- return self.backend.bk_reshape(res,[npix])
684
-
685
-
686
- #--------------------------------------------------------
687
- def smooth_1d(self,im,axis=0):
688
-
689
- xx=np.arange(5)-2
690
- w=np.exp(-0.17328679514*(xx)**2)
691
- w=w/w.sum()
692
- w=np.array(w).reshape(xx.shape[0],1,1)
693
-
694
- npix=im.shape[axis]
695
-
696
- ishape=list(im.shape)
697
- odata=1
698
- for k in range(axis+1,len(ishape)):
699
- odata=odata*ishape[k]
700
-
701
- ndata=1
1266
+ return self.backend.bk_reshape(
1267
+ res, ishape[0:axis] + [npix] + ishape[axis + 1 :]
1268
+ )
1269
+ return self.backend.bk_reshape(res, [npix])
1270
+
1271
+ # --------------------------------------------------------
1272
+ def smooth_1d(self, im, axis=0):
1273
+
1274
+ xx = np.arange(5) - 2
1275
+ w = np.exp(-0.17328679514 * (xx) ** 2)
1276
+ w = w / w.sum()
1277
+ w = np.array(w).reshape(xx.shape[0], 1, 1)
1278
+
1279
+ npix = im.shape[axis]
1280
+
1281
+ ishape = list(im.shape)
1282
+ odata = 1
1283
+ for k in range(axis + 1, len(ishape)):
1284
+ odata = odata * ishape[k]
1285
+
1286
+ ndata = 1
702
1287
  for k in range(axis):
703
- ndata=ndata*ishape[k]
704
-
705
- if odata>1:
706
- w=np.repeat(w,odata,2)
707
-
708
- w=self.backend.bk_cast(self.backend.constant(w))
709
-
710
- tim = self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
711
-
712
- if tim.dtype==self.all_cbk_type:
713
- rr = self.backend.bk_conv1d(self.backend.bk_real(tim),w)
714
- ii = self.backend.bk_conv1d(self.backend.bk_real(tim),w)
715
- res=self.backend.bk_complex(rr,ii)
1288
+ ndata = ndata * ishape[k]
1289
+
1290
+ if odata > 1:
1291
+ w = np.repeat(w, odata, 2)
1292
+
1293
+ w = self.backend.bk_cast(self.backend.constant(w))
1294
+
1295
+ tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1296
+
1297
+ if tim.dtype == self.all_cbk_type:
1298
+ rr = self.backend.bk_conv1d(self.backend.bk_real(tim), w)
1299
+ ii = self.backend.bk_conv1d(self.backend.bk_real(tim), w)
1300
+ res = self.backend.bk_complex(rr, ii)
716
1301
  else:
717
- res=self.backend.bk_conv1d(tim,w)
718
-
719
- if axis==0:
720
- if len(ishape)==1:
721
- return self.backend.bk_reshape(res,[npix])
1302
+ res = self.backend.bk_conv1d(tim, w)
1303
+
1304
+ if axis == 0:
1305
+ if len(ishape) == 1:
1306
+ return self.backend.bk_reshape(res, [npix])
722
1307
  else:
723
- return self.backend.bk_reshape(res,[npix]+ishape[axis+1:])
1308
+ return self.backend.bk_reshape(res, [npix] + ishape[axis + 1 :])
724
1309
  else:
725
- if len(ishape)==axis+1:
726
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix])
1310
+ if len(ishape) == axis + 1:
1311
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix])
727
1312
  else:
728
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix]+ishape[axis+1:])
729
- return self.backend.bk_reshape(res,[npix])
730
-
731
- #--------------------------------------------------------
732
- def up_grade_1d(self,im,nout,axis=0):
733
-
734
- lout=int(im.shape[axis])
735
- nscale=int(np.log(nout//lout)/np.log(2))
736
- res=self.backend.bk_cast(im)
1313
+ return self.backend.bk_reshape(
1314
+ res, ishape[0:axis] + [npix] + ishape[axis + 1 :]
1315
+ )
1316
+ return self.backend.bk_reshape(res, [npix])
1317
+
1318
+ # --------------------------------------------------------
1319
+ def up_grade_1d(self, im, nout, axis=0):
1320
+
1321
+ lout = int(im.shape[axis])
1322
+ nscale = int(np.log(nout // lout) / np.log(2))
1323
+ res = self.backend.bk_cast(im)
737
1324
  for k in range(nscale):
738
- res=self.up_grade_2_1d(res,axis=axis)
739
- return(res)
740
-
1325
+ res = self.up_grade_2_1d(res, axis=axis)
1326
+ return res
1327
+
741
1328
  # ---------------------------------------------−---------
742
- def init_index(self,nside,kernel=-1):
1329
+ def init_index(self, nside, kernel=-1):
743
1330
 
744
- if kernel==-1:
745
- l_kernel=self.KERNELSZ
1331
+ if kernel == -1:
1332
+ l_kernel = self.KERNELSZ
746
1333
  else:
747
- l_kernel=kernel
748
-
749
-
1334
+ l_kernel = kernel
1335
+
750
1336
  try:
751
1337
  if self.use_2D:
752
- tmp=np.load('%s/W%d_%s_%d_IDX.npy'%(self.TEMPLATE_PATH,l_kernel**2,TMPFILE_VERSION,nside))
1338
+ tmp = np.load(
1339
+ "%s/W%d_%s_%d_IDX.npy"
1340
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1341
+ )
753
1342
  else:
754
- tmp=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel**2,self.NORIENT,nside))
1343
+ tmp = np.load(
1344
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1345
+ % (
1346
+ self.TEMPLATE_PATH,
1347
+ TMPFILE_VERSION,
1348
+ l_kernel**2,
1349
+ self.NORIENT,
1350
+ nside,
1351
+ )
1352
+ )
755
1353
  except:
756
- if self.use_2D==False:
1354
+ if not self.use_2D:
1355
+
1356
+ if l_kernel == 5:
1357
+ pw = 0.5
1358
+ pw2 = 0.5
1359
+ threshold = 2e-4
1360
+
1361
+ elif l_kernel == 3:
1362
+ pw = 1.0 / np.sqrt(2)
1363
+ pw2 = 1.0
1364
+ threshold = 1e-3
1365
+
1366
+ elif l_kernel == 7:
1367
+ pw = 0.5
1368
+ pw2 = 0.25
1369
+ threshold = 4e-5
1370
+
1371
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1372
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1373
+
1374
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1375
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1376
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1377
+
1378
+ indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1379
+ indice = np.zeros(
1380
+ [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1381
+ )
1382
+ wav = np.zeros(
1383
+ [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1384
+ )
1385
+ wwav = np.zeros([12 * nside * nside * 64 * self.NORIENT], dtype="float")
1386
+
1387
+ iv = 0
1388
+ iv2 = 0
1389
+ for iii in range(12 * nside * nside):
1390
+
1391
+ if iii % (nside * nside) == nside * nside - 1:
1392
+ if not self.silent:
1393
+ print(
1394
+ "Pre-compute nside=%6d %.2f%%"
1395
+ % (nside, 100 * iii / (12 * nside * nside))
1396
+ )
1397
+
1398
+ hidx = hp.query_disc(
1399
+ nside, [x[iii], y[iii], z[iii]], 2 * np.pi / nside, nest=True
1400
+ )
1401
+
1402
+ R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1403
+
1404
+ t2, p2 = R(th[hidx], ph[hidx])
1405
+
1406
+ vec2 = hp.ang2vec(t2, p2)
1407
+
1408
+ x2 = vec2[:, 0]
1409
+ y2 = vec2[:, 1]
1410
+ z2 = vec2[:, 2]
1411
+
1412
+ ww = np.exp(
1413
+ -pw2
1414
+ * ((nside) ** 2)
1415
+ * ((x2) ** 2 + (y2) ** 2 + (z2 - 1.0) ** 2)
1416
+ )
1417
+ idx = np.where((ww**2) > threshold)[0]
1418
+ nval2 = len(idx)
1419
+ indice2[iv2 : iv2 + nval2, 0] = iii
1420
+ indice2[iv2 : iv2 + nval2, 1] = hidx[idx]
1421
+ wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1422
+ iv2 += nval2
1423
+
1424
+ for l_rotation in range(self.NORIENT):
1425
+
1426
+ angle = (
1427
+ l_rotation / 4.0 * np.pi
1428
+ - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1429
+ - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1430
+ )
1431
+
1432
+ # posi=2*(0.5-(z[hidx]<0))
1433
+
1434
+ axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1435
+ wresr = ww * np.cos(pw * axes * (nside) * np.pi)
1436
+ wresi = ww * np.sin(pw * axes * (nside) * np.pi)
1437
+
1438
+ vnorm = wresr * wresr + wresi * wresi
1439
+ idx = np.where(vnorm > threshold)[0]
1440
+
1441
+ nval = len(idx)
1442
+ indice[iv : iv + nval, 0] = iii * 4 + l_rotation
1443
+ indice[iv : iv + nval, 1] = hidx[idx]
1444
+ # print([hidx[k] for k in idx])
1445
+ # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1446
+ normr = np.mean(wresr[idx])
1447
+ normi = np.mean(wresi[idx])
1448
+
1449
+ val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1450
+ val = val / abs(val).sum()
1451
+
1452
+ wav[iv : iv + nval] = val
1453
+ iv += nval
1454
+
1455
+ indice = indice[:iv, :]
1456
+ wav = wav[:iv]
1457
+ indice2 = indice2[:iv2, :]
1458
+ wwav = wwav[:iv2]
1459
+ if not self.silent:
1460
+ print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1461
+ """
1462
+ # OLD VERSION OLD VERSION OLD VERSION (3.0)
757
1463
  if self.KERNELSZ*self.KERNELSZ>12*nside*nside:
758
- l_kernel=2*nside
759
-
1464
+ l_kernel=3
1465
+
760
1466
  aa=np.cos(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
761
1467
  bb=np.sin(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
762
1468
  x,y,z=hp.pix2vec(nside,np.arange(12*nside*nside),nest=True)
@@ -773,36 +1479,42 @@ class FoCUS:
773
1479
  lidx=np.arange(12*nside*nside)
774
1480
 
775
1481
  pw=np.pi/4.0
776
- pw2=1/2.0
777
-
1482
+ pw2=1/2
1483
+ amp=1.0
1484
+
778
1485
  if l_kernel==5:
779
1486
  pw=np.pi/4.0
780
- pw2=1/2.0
1487
+ pw2=1/2.25
1488
+ amp=1.0/9.2038
1489
+
781
1490
  elif l_kernel==3:
782
- pw=1.0
1491
+ pw=1.0/np.sqrt(2)
783
1492
  pw2=1.0
1493
+ amp=1/8.45
1494
+
784
1495
  elif l_kernel==7:
785
1496
  pw=np.pi/4.0
786
1497
  pw2=1.0/3.0
787
-
1498
+
788
1499
  for k in range(12*nside*nside):
789
1500
  if k%(nside*nside)==0:
790
- print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
1501
+ if not self.silent:
1502
+ print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
791
1503
  if nside>scale*2:
792
1504
  lidx=hp.get_all_neighbours(nside//scale,th[k//(scale*scale)],ph[k//(scale*scale)],nest=True)
793
1505
  lidx=np.concatenate([lidx,np.array([(k//(scale*scale))])],0)
794
1506
  lidx=np.repeat(lidx*(scale*scale),(scale*scale))+ \
795
1507
  np.tile(np.arange((scale*scale)),lidx.shape[0])
796
-
1508
+
797
1509
  delta=(x[lidx]-x[k])**2+(y[lidx]-y[k])**2+(z[lidx]-z[k])**2
798
1510
  pidx=np.where(delta<(10)/(nside**2))[0]
799
1511
  if len(pidx)<l_kernel**2:
800
1512
  pidx=np.arange(delta.shape[0])
801
-
1513
+
802
1514
  w=np.exp(-pw2*delta[pidx]*(nside**2))
803
1515
  pidx=pidx[np.argsort(-w)[0:l_kernel**2]]
804
1516
  pidx=pidx[np.argsort(lidx[pidx])]
805
-
1517
+
806
1518
  w=np.exp(-pw2*delta[pidx]*(nside**2))
807
1519
  iwav[k]=lidx[pidx]
808
1520
  wwav[k]=w
@@ -810,16 +1522,16 @@ class FoCUS:
810
1522
  r=hp.Rotator(rot=rot)
811
1523
  ty,tx=r(to[iwav[k]],po[iwav[k]])
812
1524
  ty=ty-np.pi/2
813
-
1525
+
814
1526
  xx=np.expand_dims(pw*nside*np.pi*tx/np.cos(ty),-1)
815
1527
  yy=np.expand_dims(pw*nside*np.pi*ty,-1)
816
-
1528
+
817
1529
  wav[k,:,:]=(np.cos(xx*aa+yy*bb)+complex(0.0,1.0)*np.sin(xx*aa+yy*bb))*np.expand_dims(w,-1)
818
-
1530
+
819
1531
  wav=wav-np.expand_dims(np.mean(wav,1),1)
820
- wav=wav/np.expand_dims(np.std(wav,1),1)
1532
+ wav=amp*wav/np.expand_dims(np.std(wav,1),1)
821
1533
  wwav=wwav/np.expand_dims(np.sum(wwav,1),1)
822
-
1534
+
823
1535
  nk=l_kernel*l_kernel
824
1536
  indice=np.zeros([12*nside*nside*nk*self.NORIENT,2],dtype='int')
825
1537
  lidx=np.arange(self.NORIENT)
@@ -831,561 +1543,1164 @@ class FoCUS:
831
1543
  for i in range(12*nside*nside):
832
1544
  indice2[i*nk:i*nk+nk,0]=i
833
1545
  indice2[i*nk:i*nk+nk,1]=iwav[i]
834
-
1546
+
835
1547
  w=np.zeros([12*nside*nside,wav.shape[2],wav.shape[1]],dtype='complex')
836
1548
  for i in range(wav.shape[1]):
837
1549
  for j in range(wav.shape[2]):
838
1550
  w[:,j,i]=wav[:,i,j]
839
1551
  wav=w.flatten()
840
1552
  wwav=wwav.flatten()
841
-
842
- print('Write FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
843
- np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice)
844
- np.save('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),wav)
845
- np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice2)
846
- np.save('%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),wwav)
1553
+ """
1554
+ if not self.silent:
1555
+ print(
1556
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1557
+ % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1558
+ )
1559
+ np.save(
1560
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1561
+ % (
1562
+ self.TEMPLATE_PATH,
1563
+ TMPFILE_VERSION,
1564
+ self.KERNELSZ**2,
1565
+ self.NORIENT,
1566
+ nside,
1567
+ ),
1568
+ indice,
1569
+ )
1570
+ np.save(
1571
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1572
+ % (
1573
+ self.TEMPLATE_PATH,
1574
+ TMPFILE_VERSION,
1575
+ self.KERNELSZ**2,
1576
+ self.NORIENT,
1577
+ nside,
1578
+ ),
1579
+ wav,
1580
+ )
1581
+ np.save(
1582
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1583
+ % (
1584
+ self.TEMPLATE_PATH,
1585
+ TMPFILE_VERSION,
1586
+ self.KERNELSZ**2,
1587
+ self.NORIENT,
1588
+ nside,
1589
+ ),
1590
+ indice2,
1591
+ )
1592
+ np.save(
1593
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1594
+ % (
1595
+ self.TEMPLATE_PATH,
1596
+ TMPFILE_VERSION,
1597
+ self.KERNELSZ**2,
1598
+ self.NORIENT,
1599
+ nside,
1600
+ ),
1601
+ wwav,
1602
+ )
847
1603
  else:
848
- if l_kernel**2==9:
849
- if self.rank==0:
1604
+ if l_kernel**2 == 9:
1605
+ if self.rank == 0:
850
1606
  self.comp_idx_w9(nside)
851
- elif l_kernel**2==25:
852
- if self.rank==0:
1607
+ elif l_kernel**2 == 25:
1608
+ if self.rank == 0:
853
1609
  self.comp_idx_w25(nside)
854
1610
  else:
855
- if self.rank==0:
856
- print('Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d'%(KERNELSZ,KERNELSZ))
857
- exit(0)
858
-
859
- self.barrier()
860
- if self.use_2D:
861
- tmp=np.load('%s/W%d_%s_%d_IDX.npy'%(self.TEMPLATE_PATH,l_kernel**2,TMPFILE_VERSION,nside))
1611
+ if self.rank == 0:
1612
+ if not self.silent:
1613
+ print(
1614
+ "Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d"
1615
+ % (self.KERNELSZ, self.KERNELSZ)
1616
+ )
1617
+ return None
1618
+
1619
+ self.barrier()
1620
+ if self.use_2D:
1621
+ tmp = np.load(
1622
+ "%s/W%d_%s_%d_IDX.npy"
1623
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1624
+ )
862
1625
  else:
863
- tmp=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
864
- tmp2=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
865
- wr=np.load('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside)).real
866
- wi=np.load('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside)).imag
867
- ws=self.slope*np.load('%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
868
-
869
- wr=self.backend.bk_SparseTensor(self.backend.constant(tmp),self.backend.constant(self.backend.bk_cast(wr)),dense_shape=[12*nside**2*self.NORIENT,12*nside**2])
870
- wi=self.backend.bk_SparseTensor(self.backend.constant(tmp),self.backend.constant(self.backend.bk_cast(wi)),dense_shape=[12*nside**2*self.NORIENT,12*nside**2])
871
- ws=self.backend.bk_SparseTensor(self.backend.constant(tmp2),self.backend.constant(self.backend.bk_cast(ws)),dense_shape=[12*nside**2,12*nside**2])
872
-
873
- if kernel==-1:
874
- self.Idx_Neighbours[nside]=tmp
875
-
1626
+ tmp = np.load(
1627
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1628
+ % (
1629
+ self.TEMPLATE_PATH,
1630
+ TMPFILE_VERSION,
1631
+ self.KERNELSZ**2,
1632
+ self.NORIENT,
1633
+ nside,
1634
+ )
1635
+ )
1636
+ tmp2 = np.load(
1637
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1638
+ % (
1639
+ self.TEMPLATE_PATH,
1640
+ TMPFILE_VERSION,
1641
+ self.KERNELSZ**2,
1642
+ self.NORIENT,
1643
+ nside,
1644
+ )
1645
+ )
1646
+ wr = np.load(
1647
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1648
+ % (
1649
+ self.TEMPLATE_PATH,
1650
+ TMPFILE_VERSION,
1651
+ self.KERNELSZ**2,
1652
+ self.NORIENT,
1653
+ nside,
1654
+ )
1655
+ ).real
1656
+ wi = np.load(
1657
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1658
+ % (
1659
+ self.TEMPLATE_PATH,
1660
+ TMPFILE_VERSION,
1661
+ self.KERNELSZ**2,
1662
+ self.NORIENT,
1663
+ nside,
1664
+ )
1665
+ ).imag
1666
+ ws = self.slope * np.load(
1667
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1668
+ % (
1669
+ self.TEMPLATE_PATH,
1670
+ TMPFILE_VERSION,
1671
+ self.KERNELSZ**2,
1672
+ self.NORIENT,
1673
+ nside,
1674
+ )
1675
+ )
1676
+
1677
+ wr = self.backend.bk_SparseTensor(
1678
+ self.backend.constant(tmp),
1679
+ self.backend.constant(self.backend.bk_cast(wr)),
1680
+ dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1681
+ )
1682
+ wi = self.backend.bk_SparseTensor(
1683
+ self.backend.constant(tmp),
1684
+ self.backend.constant(self.backend.bk_cast(wi)),
1685
+ dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1686
+ )
1687
+ ws = self.backend.bk_SparseTensor(
1688
+ self.backend.constant(tmp2),
1689
+ self.backend.constant(self.backend.bk_cast(ws)),
1690
+ dense_shape=[12 * nside**2, 12 * nside**2],
1691
+ )
1692
+
1693
+ if kernel == -1:
1694
+ self.Idx_Neighbours[nside] = tmp
1695
+
876
1696
  if self.use_2D:
877
- if kernel!=-1:
1697
+ if kernel != -1:
878
1698
  return tmp
879
-
880
- return wr,wi,ws,tmp
881
-
882
- # ---------------------------------------------−---------
883
- # Compute x [....,a,....] to [....,a*a,....]
884
- #NOT YET TESTED OR IMPLEMENTED
885
- def auto_cross_2(x,axis=0):
886
- shape=np.array(x.shape)
887
- if axis==0:
888
- y1=self.reshape(x,[shape[0],1,np.cumprod(shape[1:])])
889
- y2=self.reshape(x,[1,shape[0],np.cumprod(shape[1:])])
890
- oshape=np.concat([shape[0],shape[0],shape[1:]])
891
- return(self.reshape(y1*y2,oshape))
892
-
893
- # ---------------------------------------------−---------
894
- # Compute x [....,a,....,b,....] to [....,b*b,....,a*a,....]
895
- #NOT YET TESTED OR IMPLEMENTED
896
- def auto_cross_2(x,axis1=0,axis2=1):
897
- shape=np.array(x.shape)
898
- if axis==0:
899
- y1=self.reshape(x,[shape[0],1,np.cumprod(shape[1:])])
900
- y2=self.reshape(x,[1,shape[0],np.cumprod(shape[1:])])
901
- oshape=np.concat([shape[0],shape[0],shape[1:]])
902
- return(self.reshape(y1*y2,oshape))
903
-
904
-
1699
+
1700
+ return wr, wi, ws, tmp
1701
+
905
1702
  # ---------------------------------------------−---------
906
1703
  # convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
907
- def swapaxes(self,x,axis1,axis2):
908
- shape=list(x.shape)
909
- if axis1<0:
910
- laxis1=len(shape)+axis1
1704
+ def swapaxes(self, x, axis1, axis2):
1705
+ shape = list(x.shape)
1706
+ if axis1 < 0:
1707
+ laxis1 = len(shape) + axis1
911
1708
  else:
912
- laxis1=axis1
913
- if axis2<0:
914
- laxis2=len(shape)+axis2
1709
+ laxis1 = axis1
1710
+ if axis2 < 0:
1711
+ laxis2 = len(shape) + axis2
915
1712
  else:
916
- laxis2=axis2
917
-
918
- naxes=len(shape)
919
- thelist=[i for i in range(naxes)]
920
- thelist[laxis1]=laxis2
921
- thelist[laxis2]=laxis1
922
- return self.backend.bk_transpose(x,thelist)
923
-
1713
+ laxis2 = axis2
1714
+
1715
+ naxes = len(shape)
1716
+ thelist = [i for i in range(naxes)]
1717
+ thelist[laxis1] = laxis2
1718
+ thelist[laxis2] = laxis1
1719
+ return self.backend.bk_transpose(x, thelist)
1720
+
924
1721
  # ---------------------------------------------−---------
925
1722
  # Mean using mask x [....,Npix,....], mask[Nmask,Npix] to [....,Nmask,....]
926
1723
  # if use_2D
927
1724
  # Mean using mask x [....,12,Nside+2*off,Nside+2*off,....], mask[Nmask,12,Nside+2*off,Nside+2*off] to [....,Nmask,....]
928
- def masked_mean(self,x,mask,axis=0,rank=0,calc_var=False):
929
-
930
- #==========================================================================
1725
+ def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
1726
+
1727
+ # ==========================================================================
931
1728
  # in input data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]]
932
1729
  # in input mask=[Nmask,X[,Y]]
933
1730
  # if self.use_2D : X[,Y]] = [X,Y]
934
1731
  # if second level: NORIENT[,NORIENT]= NORIENT,NORIENT
935
- #==========================================================================
936
-
937
- shape=list(x.shape)
938
-
1732
+ # ==========================================================================
1733
+
1734
+ shape = list(x.shape)
1735
+
939
1736
  if not self.use_2D:
940
- nside=int(np.sqrt(x.shape[axis]//12))
941
-
942
- l_mask=mask
1737
+ nside = int(np.sqrt(x.shape[axis] // 12))
1738
+
1739
+ l_mask = mask
943
1740
  if self.mask_norm:
944
- sum_mask=self.backend.bk_reduce_sum(self.backend.bk_reshape(l_mask,[l_mask.shape[0],np.prod(np.array(l_mask.shape[1:]))]),1)
1741
+ sum_mask = self.backend.bk_reduce_sum(
1742
+ self.backend.bk_reshape(
1743
+ l_mask, [l_mask.shape[0], np.prod(np.array(l_mask.shape[1:]))]
1744
+ ),
1745
+ 1,
1746
+ )
945
1747
  if not self.use_2D:
946
- l_mask=12*nside*nside*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
947
- else:
948
- l_mask=mask.shape[1]*mask.shape[2]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
949
-
950
- if self.use_2D:
951
- if self.padding=='VALID' and shape[axis]!=l_mask.shape[1]:
952
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
953
- if shape[axis]!=l_mask.shape[1]:
954
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1748
+ l_mask = (
1749
+ 12
1750
+ * nside
1751
+ * nside
1752
+ * l_mask
1753
+ / self.backend.bk_reshape(
1754
+ sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]]
1755
+ )
1756
+ )
1757
+ elif self.use_2D:
1758
+ l_mask = (
1759
+ mask.shape[1]
1760
+ * mask.shape[2]
1761
+ * l_mask
1762
+ / self.backend.bk_reshape(
1763
+ sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]]
1764
+ )
1765
+ )
955
1766
  else:
956
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
957
-
958
- # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,...,KERNELSZ//2:-self.KERNELSZ//2,KERNELSZ//2:-self.KERNELSZ//2,NORIENT[,NORIENT]]
1767
+ l_mask = (
1768
+ mask.shape[1]
1769
+ * l_mask
1770
+ / self.backend.bk_reshape(
1771
+ sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]]
1772
+ )
1773
+ )
1774
+
959
1775
  if self.use_2D:
960
- ichannel=1
1776
+ if self.padding == "VALID":
1777
+ l_mask = l_mask[
1778
+ :,
1779
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1780
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1781
+ ]
1782
+ if shape[axis] != l_mask.shape[1]:
1783
+ l_mask = l_mask[
1784
+ :,
1785
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1786
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1787
+ ]
1788
+
1789
+ ichannel = 1
1790
+ for i in range(axis):
1791
+ ichannel *= shape[i]
1792
+ ochannel = 1
1793
+ for i in range(axis + 2, len(shape)):
1794
+ ochannel *= shape[i]
1795
+ l_x = self.backend.bk_reshape(
1796
+ x, [ichannel, 1, shape[axis], shape[axis + 1], ochannel]
1797
+ )
1798
+
1799
+ if self.padding == "VALID":
1800
+ oshape = [k for k in shape]
1801
+ oshape[axis] = oshape[axis] - self.KERNELSZ + 1
1802
+ oshape[axis + 1] = oshape[axis + 1] - self.KERNELSZ + 1
1803
+ l_x = self.backend.bk_reshape(
1804
+ l_x[
1805
+ :,
1806
+ :,
1807
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1808
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1809
+ :,
1810
+ ],
1811
+ oshape,
1812
+ )
1813
+
1814
+ elif self.use_1D:
1815
+ if self.padding == "VALID":
1816
+ l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1]
1817
+ if shape[axis] != l_mask.shape[1]:
1818
+ l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1]
1819
+
1820
+ ichannel = 1
961
1821
  for i in range(axis):
962
- ichannel*=shape[i]
963
- ochannel=1
964
- for i in range(axis+2,len(shape)):
965
- ochannel*=shape[i]
966
- l_x=self.backend.bk_reshape(x,[ichannel,shape[axis],shape[axis+1],ochannel])
967
- oshape=[k for k in shape]
968
- oshape[axis]=oshape[axis]-self.KERNELSZ+1
969
- oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
970
- l_x=self.backend.bk_reshape(l_x[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1822
+ ichannel *= shape[i]
1823
+ ochannel = 1
1824
+ for i in range(axis + 1, len(shape)):
1825
+ ochannel *= shape[i]
1826
+ l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1827
+
1828
+ if self.padding == "VALID":
1829
+ oshape = [k for k in shape]
1830
+ oshape[axis] = oshape[axis] - self.KERNELSZ + 1
1831
+ l_x = self.backend.bk_reshape(
1832
+ l_x[:, :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, :], oshape
1833
+ )
971
1834
  else:
972
- l_x=x
973
-
1835
+ ichannel = 1
1836
+ for i in range(axis):
1837
+ ichannel *= shape[i]
1838
+ ochannel = 1
1839
+ for i in range(axis + 1, len(shape)):
1840
+ ochannel *= shape[i]
1841
+ l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1842
+
974
1843
  # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,1,...,X[,Y],NORIENT[,NORIENT]]
975
- l_x=self.backend.bk_expand_dims(l_x,1)
976
-
977
1844
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,X[,Y]]
978
- l_mask=self.backend.bk_expand_dims(l_mask,0)
979
-
980
- # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
981
- for i in range(1,axis):
982
- l_mask=self.backend.bk_expand_dims(l_mask,axis)
983
-
984
- if l_x.dtype==self.all_cbk_type:
985
- l_mask=self.backend.bk_complex(l_mask,self.backend.bk_cast(0.0*l_mask))
986
-
1845
+ l_mask = self.backend.bk_expand_dims(l_mask, 0)
1846
+ # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,X[,Y],1]
1847
+ l_mask = self.backend.bk_expand_dims(l_mask, -1)
1848
+
1849
+ if l_x.dtype == self.all_cbk_type:
1850
+ l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
1851
+
987
1852
  if self.use_2D:
1853
+ mtmp = l_mask
1854
+ vtmp = l_x
1855
+
1856
+ v1 = self.backend.bk_reduce_sum(
1857
+ self.backend.bk_reduce_sum(mtmp * vtmp, axis=2), 2
1858
+ )
1859
+ v2 = self.backend.bk_reduce_sum(
1860
+ self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2), 2
1861
+ )
1862
+ vh = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp, axis=2), 2)
1863
+
1864
+ res = v1 / vh
1865
+
1866
+ oshape = []
1867
+ if axis > 0:
1868
+ oshape = oshape + list(x.shape[0:axis])
1869
+ oshape = oshape + [mask.shape[0]]
1870
+ if axis + 1 < len(x.shape):
1871
+ oshape = oshape + list(x.shape[axis + 2 :])
1872
+
1873
+ if calc_var:
1874
+ if self.backend.bk_is_complex(vtmp):
1875
+ res2 = self.backend.bk_sqrt(
1876
+ (
1877
+ (
1878
+ self.backend.bk_real(v2) / self.backend.bk_real(vh)
1879
+ - self.backend.bk_real(res) * self.backend.bk_real(res)
1880
+ )
1881
+ + (
1882
+ self.backend.bk_imag(v2) / self.backend.bk_real(vh)
1883
+ - self.backend.bk_imag(res) * self.backend.bk_imag(res)
1884
+ )
1885
+ )
1886
+ / self.backend.bk_real(vh)
1887
+ )
1888
+ else:
1889
+ res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1890
+
1891
+ res = self.backend.bk_reshape(res, oshape)
1892
+ res2 = self.backend.bk_reshape(res2, oshape)
1893
+ return res, res2
1894
+ else:
1895
+ res = self.backend.bk_reshape(res, oshape)
1896
+ return res
988
1897
 
989
- # mask=[1,Nmask,....,X,Y] => mask=[1,Nmask,....,X,Y,....]
990
- for i in range(axis+2,len(x.shape)):
991
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1898
+ elif self.use_1D:
1899
+ mtmp = l_mask
1900
+ vtmp = l_x
992
1901
 
993
- shape1=list(l_mask.shape)
994
- shape2=list(l_x.shape)
1902
+ v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=2)
1903
+ v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2)
1904
+ vh = self.backend.bk_reduce_sum(mtmp, axis=2)
995
1905
 
996
- oshape1=shape1[0:axis+1]+[shape1[axis+1]*shape1[axis+2]]+shape1[axis+3:]
997
- oshape2=shape2[0:axis+1]+[shape2[axis+1]*shape2[axis+2]]+shape2[axis+3:]
998
-
999
- mtmp=self.backend.bk_reshape(l_mask,oshape1)
1000
- vtmp=self.backend.bk_reshape(l_x,oshape2)
1906
+ res = v1 / vh
1001
1907
 
1002
- v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=axis+1)
1003
- v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=axis+1)
1004
- vh=self.backend.bk_reduce_sum(mtmp,axis=axis+1)
1908
+ oshape = []
1909
+ if axis > 0:
1910
+ oshape = oshape + list(x.shape[0:axis])
1911
+ oshape = oshape + [mask.shape[0]]
1912
+ if axis + 1 < len(x.shape):
1913
+ oshape = oshape + list(x.shape[axis + 1 :])
1005
1914
 
1006
- res=v1/vh
1007
1915
  if calc_var:
1008
- if vtmp.dtype=='complex128' or vtmp.dtype=='complex64':
1009
- res2=self.backend.bk_complex(self.backend.bk_sqrt(self.backend.bk_real(v2)/self.backend.bk_real(vh)
1010
- -self.backend.bk_real(res)*self.backend.bk_real(res)), \
1011
- self.backend.bk_sqrt(self.backend.bk_imag(v2)/self.backend.bk_real(vh)
1012
- -self.backend.bk_imag(res)*self.backend.bk_imag(res)))
1916
+ if self.backend.bk_is_complex(vtmp):
1917
+ res2 = self.backend.bk_sqrt(
1918
+ (
1919
+ (
1920
+ self.backend.bk_real(v2) / self.backend.bk_real(vh)
1921
+ - self.backend.bk_real(res) * self.backend.bk_real(res)
1922
+ )
1923
+ + (
1924
+ self.backend.bk_imag(v2) / self.backend.bk_real(vh)
1925
+ - self.backend.bk_imag(res) * self.backend.bk_imag(res)
1926
+ )
1927
+ )
1928
+ / self.backend.bk_real(vh)
1929
+ )
1013
1930
  else:
1014
- res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1015
- return res,res2
1931
+ res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1932
+
1933
+ res = self.backend.bk_reshape(res, oshape)
1934
+ res2 = self.backend.bk_reshape(res2, oshape)
1935
+ return res, res2
1016
1936
  else:
1937
+ res = self.backend.bk_reshape(res, oshape)
1017
1938
  return res
1939
+
1018
1940
  else:
1019
- # mask=[1,Nmask,....,X] => mask=[1,Nmask,....,X,....]
1020
- for i in range(axis+1,len(x.shape)):
1021
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1022
-
1023
- v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=axis+1)
1024
- v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=axis+1)
1025
- vh=self.backend.bk_reduce_sum(l_mask,axis=axis+1)
1026
-
1027
- res=v1/vh
1941
+ v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=2)
1942
+ v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=2)
1943
+ vh = self.backend.bk_reduce_sum(l_mask, axis=2)
1944
+
1945
+ res = v1 / vh
1946
+
1947
+ oshape = []
1948
+ if axis > 0:
1949
+ oshape = oshape + list(x.shape[0:axis])
1950
+ oshape = oshape + [mask.shape[0]]
1951
+ if axis + 1 < len(x.shape):
1952
+ oshape = oshape + list(x.shape[axis + 1 :])
1953
+
1028
1954
  if calc_var:
1029
- if l_x.dtype=='complex128' or l_x.dtype=='complex64':
1030
- res2=self.backend.bk_complex(self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1031
- -self.backend.bk_real(res)*self.backend.bk_real(res))/self.backend.bk_real(v2)), \
1032
- self.backend.bk_sqrt((self.backend.bk_imag(v2)/self.backend.bk_real(vh)
1033
- -self.backend.bk_imag(res)*self.backend.bk_imag(res))/self.backend.bk_real(v2)))
1955
+ if self.backend.bk_is_complex(l_x):
1956
+ res2 = self.backend.bk_sqrt(
1957
+ (
1958
+ self.backend.bk_real(v2) / self.backend.bk_real(vh)
1959
+ - self.backend.bk_real(res) * self.backend.bk_real(res)
1960
+ + self.backend.bk_imag(v2) / self.backend.bk_real(vh)
1961
+ - self.backend.bk_imag(res) * self.backend.bk_imag(res)
1962
+ )
1963
+ / self.backend.bk_real(vh)
1964
+ )
1034
1965
  else:
1035
- res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1036
- return res,res2
1966
+ res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1967
+
1968
+ res = self.backend.bk_reshape(res, oshape)
1969
+ res2 = self.backend.bk_reshape(res2, oshape)
1970
+ return res, res2
1037
1971
  else:
1972
+ res = self.backend.bk_reshape(res, oshape)
1038
1973
  return res
1039
-
1974
+
1040
1975
  # ---------------------------------------------−---------
1041
1976
  # convert tensor x [....,a,b,....] to [....,a*b,....]
1042
- def reduce_dim(self,x,axis=0):
1043
- shape=list(x.shape)
1044
-
1045
- if axis<0:
1046
- laxis=len(shape)+axis
1977
+ def reduce_dim(self, x, axis=0):
1978
+ shape = list(x.shape)
1979
+
1980
+ if axis < 0:
1981
+ laxis = len(shape) + axis
1047
1982
  else:
1048
- laxis=axis
1049
-
1050
- if laxis>0 :
1051
- oshape=shape[0:laxis]
1052
- oshape.append(shape[laxis]*shape[laxis+1])
1983
+ laxis = axis
1984
+
1985
+ if laxis > 0:
1986
+ oshape = shape[0:laxis]
1987
+ oshape.append(shape[laxis] * shape[laxis + 1])
1053
1988
  else:
1054
- oshape=[shape[laxis]*shape[laxis+1]]
1055
-
1056
- if laxis<len(shape)-1:
1057
- oshape.extend(shape[laxis+2:])
1058
-
1059
- return(self.backend.bk_reshape(x,oshape))
1060
-
1061
-
1989
+ oshape = [shape[laxis] * shape[laxis + 1]]
1990
+
1991
+ if laxis < len(shape) - 1:
1992
+ oshape.extend(shape[laxis + 2 :])
1993
+
1994
+ return self.backend.bk_reshape(x, oshape)
1995
+
1062
1996
  # ---------------------------------------------−---------
1063
- def conv2d(self,image,ww,axis=0):
1997
+ def conv2d(self, image, ww, axis=0):
1064
1998
 
1065
- if len(ww.shape)==2:
1066
- norient=ww.shape[1]
1999
+ if len(ww.shape) == 2:
2000
+ norient = ww.shape[1]
1067
2001
  else:
1068
- norient=ww.shape[2]
2002
+ norient = ww.shape[2]
1069
2003
 
1070
- shape=image.shape
2004
+ shape = image.shape
1071
2005
 
1072
- if axis>0:
1073
- o_shape=shape[0]
1074
- for k in range(1,axis+1):
1075
- o_shape=o_shape*shape[k]
2006
+ if axis > 0:
2007
+ o_shape = shape[0]
2008
+ for k in range(1, axis + 1):
2009
+ o_shape = o_shape * shape[k]
1076
2010
  else:
1077
- o_shape=image.shape[0]
1078
-
1079
- if len(shape)>axis+3:
1080
- ishape=shape[axis+3]
1081
- for k in range(axis+4,len(shape)):
1082
- ishape=ishape*shape[k]
1083
-
1084
- oshape=[o_shape,shape[axis+1],shape[axis+2],ishape]
1085
-
1086
- #l_image=self.swapaxes(self.bk_reshape(image,oshape),-1,-3)
1087
- l_image=self.backend.bk_reshape(image,oshape)
1088
-
1089
- l_ww=np.zeros([self.KERNELSZ,self.KERNELSZ,ishape,ishape*norient])
2011
+ o_shape = image.shape[0]
2012
+
2013
+ if len(shape) > axis + 3:
2014
+ ishape = shape[axis + 3]
2015
+ for k in range(axis + 4, len(shape)):
2016
+ ishape = ishape * shape[k]
2017
+
2018
+ oshape = [o_shape, shape[axis + 1], shape[axis + 2], ishape]
2019
+
2020
+ # l_image=self.swapaxes(self.bk_reshape(image,oshape),-1,-3)
2021
+ l_image = self.backend.bk_reshape(image, oshape)
2022
+
2023
+ l_ww = np.zeros([self.KERNELSZ, self.KERNELSZ, ishape, ishape * norient])
1090
2024
  for k in range(ishape):
1091
- l_ww[:,:,k,k*norient:(k+1)*norient]=ww.reshape(self.KERNELSZ,self.KERNELSZ,norient)
1092
-
1093
- if l_image.dtype=='complex128' or l_image.dtype=='complex64':
1094
- r=self.backend.conv2d(self.backend.bk_real(l_image),
1095
- l_ww,
1096
- strides=[1, 1, 1, 1],
1097
- padding=self.padding)
1098
- i=self.backend.conv2d(self.backend.bk_imag(l_image),
1099
- l_ww,
1100
- strides=[1, 1, 1, 1],
1101
- padding=self.padding)
1102
- res=self.backend.bk_complex(r,i)
2025
+ l_ww[:, :, k, k * norient : (k + 1) * norient] = ww.reshape(
2026
+ self.KERNELSZ, self.KERNELSZ, norient
2027
+ )
2028
+
2029
+ if self.backend.bk_is_complex(l_image):
2030
+ r = self.backend.conv2d(
2031
+ self.backend.bk_real(l_image),
2032
+ l_ww,
2033
+ strides=[1, 1, 1, 1],
2034
+ padding=self.padding,
2035
+ )
2036
+ i = self.backend.conv2d(
2037
+ self.backend.bk_imag(l_image),
2038
+ l_ww,
2039
+ strides=[1, 1, 1, 1],
2040
+ padding=self.padding,
2041
+ )
2042
+ res = self.backend.bk_complex(r, i)
1103
2043
  else:
1104
- res=self.backend.conv2d(l_image,l_ww,strides=[1, 1, 1, 1],padding=self.padding)
2044
+ res = self.backend.conv2d(
2045
+ l_image, l_ww, strides=[1, 1, 1, 1], padding=self.padding
2046
+ )
1105
2047
 
1106
- res=self.backend.bk_reshape(res,[o_shape,shape[axis+1],shape[axis+2],ishape,norient])
2048
+ res = self.backend.bk_reshape(
2049
+ res, [o_shape, shape[axis + 1], shape[axis + 2], ishape, norient]
2050
+ )
2051
+ else:
2052
+ oshape = [o_shape, shape[axis + 1], shape[axis + 2], 1]
2053
+ l_ww = self.backend.bk_reshape(
2054
+ ww, [self.KERNELSZ, self.KERNELSZ, 1, norient]
2055
+ )
2056
+
2057
+ tmp = self.backend.bk_reshape(image, oshape)
2058
+ if self.backend.bk_is_complex(tmp):
2059
+ r = self.backend.conv2d(
2060
+ self.backend.bk_real(tmp),
2061
+ l_ww,
2062
+ strides=[1, 1, 1, 1],
2063
+ padding=self.padding,
2064
+ )
2065
+ i = self.backend.conv2d(
2066
+ self.backend.bk_imag(tmp),
2067
+ l_ww,
2068
+ strides=[1, 1, 1, 1],
2069
+ padding=self.padding,
2070
+ )
2071
+ res = self.backend.bk_complex(r, i)
2072
+ else:
2073
+ res = self.backend.conv2d(
2074
+ tmp, l_ww, strides=[1, 1, 1, 1], padding=self.padding
2075
+ )
2076
+
2077
+ return self.backend.bk_reshape(res, shape + [norient])
2078
+
2079
+ def diff_data(self, x, y, is_complex=True, sigma=None):
2080
+ if sigma is None:
2081
+ if self.backend.bk_is_complex(x):
2082
+ r = self.backend.bk_square(
2083
+ self.backend.bk_real(x) - self.backend.bk_real(y)
2084
+ )
2085
+ i = self.backend.bk_square(
2086
+ self.backend.bk_imag(x) - self.backend.bk_imag(y)
2087
+ )
2088
+ return self.backend.bk_reduce_sum(r + i)
2089
+ else:
2090
+ r = self.backend.bk_square(x - y)
2091
+ return self.backend.bk_reduce_sum(r)
1107
2092
  else:
1108
- oshape=[o_shape,shape[axis+1],shape[axis+2],1]
1109
- l_ww=self.backend.bk_reshape(ww,[self.KERNELSZ,self.KERNELSZ,1,norient])
1110
-
1111
- tmp=self.backend.bk_reshape(image,oshape)
1112
- if tmp.dtype=='complex128' or tmp.dtype=='complex64':
1113
- r=self.backend.conv2d(self.backend.bk_real(tmp),
1114
- l_ww,
1115
- strides=[1, 1, 1, 1],
1116
- padding=self.padding)
1117
- i=self.backend.conv2d(self.backend.bk_imag(tmp),
1118
- l_ww,
1119
- strides=[1, 1, 1, 1],
1120
- padding=self.padding)
1121
- res=self.backend.bk_complex(r,i)
2093
+ if self.backend.bk_is_complex(x):
2094
+ r = self.backend.bk_square(
2095
+ (self.backend.bk_real(x) - self.backend.bk_real(y)) / sigma
2096
+ )
2097
+ i = self.backend.bk_square(
2098
+ (self.backend.bk_imag(x) - self.backend.bk_imag(y)) / sigma
2099
+ )
2100
+ return self.backend.bk_reduce_sum(r + i)
1122
2101
  else:
1123
- res=self.backend.conv2d(tmp,
1124
- l_ww,
1125
- strides=[1, 1, 1, 1],
1126
- padding=self.padding)
2102
+ r = self.backend.bk_square((x - y) / sigma)
2103
+ return self.backend.bk_reduce_sum(r)
1127
2104
 
1128
- return self.backend.bk_reshape(res,shape+[norient])
1129
-
1130
2105
  # ---------------------------------------------−---------
1131
- def convol(self,in_image,axis=0):
2106
+ def convol(self, in_image, axis=0):
2107
+
2108
+ image = self.backend.bk_cast(in_image)
1132
2109
 
1133
- image=self.backend.bk_cast(in_image)
1134
-
1135
2110
  if self.use_2D:
1136
-
1137
- ishape=list(in_image.shape)
1138
- if len(ishape)<axis+2:
1139
- print('Use of 2D scat with data that has less than 2D')
1140
- exit(0)
1141
-
1142
- npix=ishape[axis]
1143
- npiy=ishape[axis+1]
1144
- odata=1
1145
- if len(ishape)>axis+2:
1146
- for k in range(axis+2,len(ishape)):
1147
- odata=odata*ishape[k]
1148
-
1149
- ndata=1
2111
+ ishape = list(in_image.shape)
2112
+ if len(ishape) < axis + 2:
2113
+ if not self.silent:
2114
+ print("Use of 2D scat with data that has less than 2D")
2115
+ return None
2116
+
2117
+ npix = ishape[axis]
2118
+ npiy = ishape[axis + 1]
2119
+ odata = 1
2120
+ if len(ishape) > axis + 2:
2121
+ for k in range(axis + 2, len(ishape)):
2122
+ odata = odata * ishape[k]
2123
+
2124
+ ndata = 1
1150
2125
  for k in range(axis):
1151
- ndata=ndata*ishape[k]
1152
-
1153
- tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,npiy,odata])
1154
-
1155
- if tim.dtype=='complex128' or tim.dtype=='complex64':
1156
- rr1=self.backend.conv2d(self.backend.bk_real(tim),self.ww_RealT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1157
- ii1=self.backend.conv2d(self.backend.bk_real(tim),self.ww_ImagT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1158
- rr2=self.backend.conv2d(self.backend.bk_imag(tim),self.ww_RealT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1159
- ii2=self.backend.conv2d(self.backend.bk_imag(tim),self.ww_ImagT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1160
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2126
+ ndata = ndata * ishape[k]
2127
+
2128
+ tim = self.backend.bk_reshape(
2129
+ self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2130
+ )
2131
+
2132
+ if self.backend.bk_is_complex(tim):
2133
+ rr1 = self.backend.conv2d(
2134
+ self.backend.bk_real(tim),
2135
+ self.ww_RealT[odata],
2136
+ strides=[1, 1, 1, 1],
2137
+ padding=self.padding,
2138
+ )
2139
+ ii1 = self.backend.conv2d(
2140
+ self.backend.bk_real(tim),
2141
+ self.ww_ImagT[odata],
2142
+ strides=[1, 1, 1, 1],
2143
+ padding=self.padding,
2144
+ )
2145
+ rr2 = self.backend.conv2d(
2146
+ self.backend.bk_imag(tim),
2147
+ self.ww_RealT[odata],
2148
+ strides=[1, 1, 1, 1],
2149
+ padding=self.padding,
2150
+ )
2151
+ ii2 = self.backend.conv2d(
2152
+ self.backend.bk_imag(tim),
2153
+ self.ww_ImagT[odata],
2154
+ strides=[1, 1, 1, 1],
2155
+ padding=self.padding,
2156
+ )
2157
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1161
2158
  else:
1162
- rr=self.backend.conv2d(tim,self.ww_RealT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1163
- ii=self.backend.conv2d(tim,self.ww_ImagT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1164
- res=self.backend.bk_complex(rr,ii)
1165
-
1166
- if axis==0:
1167
- if len(ishape)==2:
1168
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2],self.NORIENT])
2159
+ rr = self.backend.conv2d(
2160
+ tim,
2161
+ self.ww_RealT[odata],
2162
+ strides=[1, 1, 1, 1],
2163
+ padding=self.padding,
2164
+ )
2165
+ ii = self.backend.conv2d(
2166
+ tim,
2167
+ self.ww_ImagT[odata],
2168
+ strides=[1, 1, 1, 1],
2169
+ padding=self.padding,
2170
+ )
2171
+ res = self.backend.bk_complex(rr, ii)
2172
+
2173
+ if axis == 0:
2174
+ if len(ishape) == 2:
2175
+ return self.backend.bk_reshape(
2176
+ res, [res.shape[1], res.shape[2], self.NORIENT]
2177
+ )
1169
2178
  else:
1170
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
2179
+ return self.backend.bk_reshape(
2180
+ res,
2181
+ [res.shape[1], res.shape[2], self.NORIENT] + ishape[axis + 2 :],
2182
+ )
1171
2183
  else:
1172
- if len(ishape)==axis+2:
1173
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT])
2184
+ if len(ishape) == axis + 2:
2185
+ return self.backend.bk_reshape(
2186
+ res, ishape[0:axis] + [res.shape[1], res.shape[2], self.NORIENT]
2187
+ )
1174
2188
  else:
1175
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
1176
-
1177
- return self.backend.bk_reshape(res,[nout,nouty])
1178
-
2189
+ return self.backend.bk_reshape(
2190
+ res,
2191
+ ishape[0:axis]
2192
+ + [res.shape[1], res.shape[2], self.NORIENT]
2193
+ + ishape[axis + 2 :],
2194
+ )
2195
+
2196
+ return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
2197
+ elif self.use_1D:
2198
+ ishape = list(in_image.shape)
2199
+ if len(ishape) < axis + 1:
2200
+ if not self.silent:
2201
+ print("Use of 1D scat with data that has less than 1D")
2202
+ return None
2203
+
2204
+ npix = ishape[axis]
2205
+ odata = 1
2206
+ if len(ishape) > axis + 1:
2207
+ for k in range(axis + 1, len(ishape)):
2208
+ odata = odata * ishape[k]
2209
+
2210
+ ndata = 1
2211
+ for k in range(axis):
2212
+ ndata = ndata * ishape[k]
2213
+
2214
+ tim = self.backend.bk_reshape(
2215
+ self.backend.bk_cast(in_image), [ndata, npix, odata]
2216
+ )
2217
+
2218
+ if self.backend.bk_is_complex(tim):
2219
+ rr1 = self.backend.conv1d(
2220
+ self.backend.bk_real(tim),
2221
+ self.ww_RealT[odata],
2222
+ strides=[1, 1, 1],
2223
+ padding=self.padding,
2224
+ )
2225
+ ii1 = self.backend.conv1d(
2226
+ self.backend.bk_real(tim),
2227
+ self.ww_ImagT[odata],
2228
+ strides=[1, 1, 1],
2229
+ padding=self.padding,
2230
+ )
2231
+ rr2 = self.backend.conv1d(
2232
+ self.backend.bk_imag(tim),
2233
+ self.ww_RealT[odata],
2234
+ strides=[1, 1, 1],
2235
+ padding=self.padding,
2236
+ )
2237
+ ii2 = self.backend.conv1d(
2238
+ self.backend.bk_imag(tim),
2239
+ self.ww_ImagT[odata],
2240
+ strides=[1, 1, 1],
2241
+ padding=self.padding,
2242
+ )
2243
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2244
+ else:
2245
+ rr = self.backend.conv1d(
2246
+ tim, self.ww_RealT[odata], strides=[1, 1, 1], padding=self.padding
2247
+ )
2248
+ ii = self.backend.conv1d(
2249
+ tim, self.ww_ImagT[odata], strides=[1, 1, 1], padding=self.padding
2250
+ )
2251
+ res = self.backend.bk_complex(rr, ii)
2252
+
2253
+ if axis == 0:
2254
+ if len(ishape) == 1:
2255
+ return self.backend.bk_reshape(res, [res.shape[1]])
2256
+ else:
2257
+ return self.backend.bk_reshape(
2258
+ res, [res.shape[1]] + ishape[axis + 2 :]
2259
+ )
2260
+ else:
2261
+ if len(ishape) == axis + 1:
2262
+ return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
2263
+ else:
2264
+ return self.backend.bk_reshape(
2265
+ res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2266
+ )
2267
+
2268
+ return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
2269
+
1179
2270
  else:
1180
- nside=int(np.sqrt(image.shape[axis]//12))
2271
+ nside = int(np.sqrt(image.shape[axis] // 12))
1181
2272
 
1182
2273
  if self.Idx_Neighbours[nside] is None:
1183
2274
  if self.InitWave is None:
1184
- wr,wi,ws,widx=self.init_index(nside)
2275
+ wr, wi, ws, widx = self.init_index(nside)
1185
2276
  else:
1186
- wr,wi,ws,widx=self.InitWave(self,nside)
1187
-
1188
- self.Idx_Neighbours[nside]=1 #self.backend.constant(tmp)
1189
- self.ww_Real[nside]=wr
1190
- self.ww_Imag[nside]=wi
1191
- self.w_smooth[nside]=ws
1192
-
1193
- l_ww_real=self.ww_Real[nside]
1194
- l_ww_imag=self.ww_Imag[nside]
1195
-
1196
- ishape=list(image.shape)
1197
- odata=1
1198
- for k in range(axis+1,len(ishape)):
1199
- odata=odata*ishape[k]
1200
-
1201
- if axis>0:
1202
- ndata=1
2277
+ wr, wi, ws, widx = self.InitWave(self, nside)
2278
+
2279
+ self.Idx_Neighbours[nside] = 1 # self.backend.constant(tmp)
2280
+ self.ww_Real[nside] = wr
2281
+ self.ww_Imag[nside] = wi
2282
+ self.w_smooth[nside] = ws
2283
+
2284
+ l_ww_real = self.ww_Real[nside]
2285
+ l_ww_imag = self.ww_Imag[nside]
2286
+
2287
+ ishape = list(image.shape)
2288
+ odata = 1
2289
+ for k in range(axis + 1, len(ishape)):
2290
+ odata = odata * ishape[k]
2291
+
2292
+ if axis > 0:
2293
+ ndata = 1
1203
2294
  for k in range(axis):
1204
- ndata=ndata*ishape[k]
1205
- tim=self.backend.bk_reshape(self.backend.bk_cast(image),[ndata,12*nside**2,odata])
1206
- if tim.dtype==self.all_cbk_type:
1207
- rr1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_real(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1208
- ii1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_real(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1209
- rr2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_imag(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1210
- ii2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_imag(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1211
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2295
+ ndata = ndata * ishape[k]
2296
+ tim = self.backend.bk_reshape(
2297
+ self.backend.bk_cast(image), [ndata, 12 * nside**2, odata]
2298
+ )
2299
+ if tim.dtype == self.all_cbk_type:
2300
+ rr1 = self.backend.bk_reshape(
2301
+ self.backend.bk_sparse_dense_matmul(
2302
+ l_ww_real, self.backend.bk_real(tim[0])
2303
+ ),
2304
+ [1, 12 * nside**2, self.NORIENT, odata],
2305
+ )
2306
+ ii1 = self.backend.bk_reshape(
2307
+ self.backend.bk_sparse_dense_matmul(
2308
+ l_ww_imag, self.backend.bk_real(tim[0])
2309
+ ),
2310
+ [1, 12 * nside**2, self.NORIENT, odata],
2311
+ )
2312
+ rr2 = self.backend.bk_reshape(
2313
+ self.backend.bk_sparse_dense_matmul(
2314
+ l_ww_real, self.backend.bk_imag(tim[0])
2315
+ ),
2316
+ [1, 12 * nside**2, self.NORIENT, odata],
2317
+ )
2318
+ ii2 = self.backend.bk_reshape(
2319
+ self.backend.bk_sparse_dense_matmul(
2320
+ l_ww_imag, self.backend.bk_imag(tim[0])
2321
+ ),
2322
+ [1, 12 * nside**2, self.NORIENT, odata],
2323
+ )
2324
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1212
2325
  else:
1213
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,tim[0]),[1,12*nside**2,self.NORIENT,odata])
1214
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,tim[0]),[1,12*nside**2,self.NORIENT,odata])
1215
- res=self.backend.bk_complex(rr,ii)
1216
-
1217
- for k in range(1,ndata):
1218
- if tim.dtype==self.all_cbk_type:
1219
- rr1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_real(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1220
- ii1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_real(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1221
- rr2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_imag(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1222
- ii2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_imag(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1223
- res=self.backend.bk_concat([res,self.backend.bk_complex(rr1-ii2,ii1+rr2)],0)
2326
+ rr = self.backend.bk_reshape(
2327
+ self.backend.bk_sparse_dense_matmul(l_ww_real, tim[0]),
2328
+ [1, 12 * nside**2, self.NORIENT, odata],
2329
+ )
2330
+ ii = self.backend.bk_reshape(
2331
+ self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[0]),
2332
+ [1, 12 * nside**2, self.NORIENT, odata],
2333
+ )
2334
+ res = self.backend.bk_complex(rr, ii)
2335
+
2336
+ for k in range(1, ndata):
2337
+ if tim.dtype == self.all_cbk_type:
2338
+ rr1 = self.backend.bk_reshape(
2339
+ self.backend.bk_sparse_dense_matmul(
2340
+ l_ww_real, self.backend.bk_real(tim[k])
2341
+ ),
2342
+ [1, 12 * nside**2, self.NORIENT, odata],
2343
+ )
2344
+ ii1 = self.backend.bk_reshape(
2345
+ self.backend.bk_sparse_dense_matmul(
2346
+ l_ww_imag, self.backend.bk_real(tim[k])
2347
+ ),
2348
+ [1, 12 * nside**2, self.NORIENT, odata],
2349
+ )
2350
+ rr2 = self.backend.bk_reshape(
2351
+ self.backend.bk_sparse_dense_matmul(
2352
+ l_ww_real, self.backend.bk_imag(tim[k])
2353
+ ),
2354
+ [1, 12 * nside**2, self.NORIENT, odata],
2355
+ )
2356
+ ii2 = self.backend.bk_reshape(
2357
+ self.backend.bk_sparse_dense_matmul(
2358
+ l_ww_imag, self.backend.bk_imag(tim[k])
2359
+ ),
2360
+ [1, 12 * nside**2, self.NORIENT, odata],
2361
+ )
2362
+ res = self.backend.bk_concat(
2363
+ [res, self.backend.bk_complex(rr1 - ii2, ii1 + rr2)], 0
2364
+ )
1224
2365
  else:
1225
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,tim[k]),[1,12*nside**2,self.NORIENT,odata])
1226
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,tim[k]),[1,12*nside**2,self.NORIENT,odata])
1227
- res=self.backend.bk_concat([res,self.backend.bk_complex(rr,ii)],0)
1228
-
1229
- if len(ishape)==axis+1:
1230
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2,self.NORIENT])
2366
+ rr = self.backend.bk_reshape(
2367
+ self.backend.bk_sparse_dense_matmul(l_ww_real, tim[k]),
2368
+ [1, 12 * nside**2, self.NORIENT, odata],
2369
+ )
2370
+ ii = self.backend.bk_reshape(
2371
+ self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[k]),
2372
+ [1, 12 * nside**2, self.NORIENT, odata],
2373
+ )
2374
+ res = self.backend.bk_concat(
2375
+ [res, self.backend.bk_complex(rr, ii)], 0
2376
+ )
2377
+
2378
+ if len(ishape) == axis + 1:
2379
+ return self.backend.bk_reshape(
2380
+ res, ishape[0:axis] + [12 * nside**2, self.NORIENT]
2381
+ )
1231
2382
  else:
1232
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2]+ishape[axis+1:]+[self.NORIENT])
1233
-
1234
- if axis==0:
1235
- tim=self.backend.bk_reshape(self.backend.bk_cast(image),[12*nside**2,odata])
1236
- if tim.dtype==self.all_cbk_type:
1237
- rr1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_real(tim)),[12*nside**2,self.NORIENT,odata])
1238
- ii1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_real(tim)),[12*nside**2,self.NORIENT,odata])
1239
- rr2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_imag(tim)),[12*nside**2,self.NORIENT,odata])
1240
- ii2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_imag(tim)),[12*nside**2,self.NORIENT,odata])
1241
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2383
+ return self.backend.bk_reshape(
2384
+ res,
2385
+ ishape[0:axis]
2386
+ + [12 * nside**2]
2387
+ + ishape[axis + 1 :]
2388
+ + [self.NORIENT],
2389
+ )
2390
+
2391
+ if axis == 0:
2392
+ tim = self.backend.bk_reshape(
2393
+ self.backend.bk_cast(image), [12 * nside**2, odata]
2394
+ )
2395
+ if tim.dtype == self.all_cbk_type:
2396
+ rr1 = self.backend.bk_reshape(
2397
+ self.backend.bk_sparse_dense_matmul(
2398
+ l_ww_real, self.backend.bk_real(tim)
2399
+ ),
2400
+ [12 * nside**2, self.NORIENT, odata],
2401
+ )
2402
+ ii1 = self.backend.bk_reshape(
2403
+ self.backend.bk_sparse_dense_matmul(
2404
+ l_ww_imag, self.backend.bk_real(tim)
2405
+ ),
2406
+ [12 * nside**2, self.NORIENT, odata],
2407
+ )
2408
+ rr2 = self.backend.bk_reshape(
2409
+ self.backend.bk_sparse_dense_matmul(
2410
+ l_ww_real, self.backend.bk_imag(tim)
2411
+ ),
2412
+ [12 * nside**2, self.NORIENT, odata],
2413
+ )
2414
+ ii2 = self.backend.bk_reshape(
2415
+ self.backend.bk_sparse_dense_matmul(
2416
+ l_ww_imag, self.backend.bk_imag(tim)
2417
+ ),
2418
+ [12 * nside**2, self.NORIENT, odata],
2419
+ )
2420
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1242
2421
  else:
1243
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,tim),[12*nside**2,self.NORIENT,odata])
1244
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,tim),[12*nside**2,self.NORIENT,odata])
1245
- res=self.backend.bk_complex(rr,ii)
1246
-
1247
- if len(ishape)==1:
1248
- return self.backend.bk_reshape(res,[12*nside**2,self.NORIENT])
2422
+ rr = self.backend.bk_reshape(
2423
+ self.backend.bk_sparse_dense_matmul(l_ww_real, tim),
2424
+ [12 * nside**2, self.NORIENT, odata],
2425
+ )
2426
+ ii = self.backend.bk_reshape(
2427
+ self.backend.bk_sparse_dense_matmul(l_ww_imag, tim),
2428
+ [12 * nside**2, self.NORIENT, odata],
2429
+ )
2430
+ res = self.backend.bk_complex(rr, ii)
2431
+
2432
+ if len(ishape) == 1:
2433
+ return self.backend.bk_reshape(res, [12 * nside**2, self.NORIENT])
1249
2434
  else:
1250
- return self.backend.bk_reshape(res,[12*nside**2]+ishape[axis+1:]+[self.NORIENT])
1251
- return(res)
1252
-
2435
+ return self.backend.bk_reshape(
2436
+ res, [12 * nside**2] + ishape[axis + 1 :] + [self.NORIENT]
2437
+ )
2438
+ return res
1253
2439
 
1254
2440
  # ---------------------------------------------−---------
1255
- def smooth(self,in_image,axis=0):
2441
+ def smooth(self, in_image, axis=0):
2442
+
2443
+ image = self.backend.bk_cast(in_image)
1256
2444
 
1257
- image=self.backend.bk_cast(in_image)
1258
-
1259
2445
  if self.use_2D:
1260
-
1261
- ishape=list(in_image.shape)
1262
- if len(ishape)<axis+2:
1263
- print('Use of 2D scat with data that has less than 2D')
1264
- exit(0)
1265
-
1266
- npix=ishape[axis]
1267
- npiy=ishape[axis+1]
1268
- odata=1
1269
- if len(ishape)>axis+2:
1270
- for k in range(axis+2,len(ishape)):
1271
- odata=odata*ishape[k]
1272
-
1273
- ndata=1
1274
- for k in range(axis):
1275
- ndata=ndata*ishape[k]
1276
2446
 
1277
- tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,npiy,odata])
2447
+ ishape = list(in_image.shape)
2448
+ if len(ishape) < axis + 2:
2449
+ if not self.silent:
2450
+ print("Use of 2D scat with data that has less than 2D")
2451
+ return None
2452
+
2453
+ npix = ishape[axis]
2454
+ npiy = ishape[axis + 1]
2455
+ odata = 1
2456
+ if len(ishape) > axis + 2:
2457
+ for k in range(axis + 2, len(ishape)):
2458
+ odata = odata * ishape[k]
1278
2459
 
1279
- if tim.dtype=='complex128' or tim.dtype=='complex64':
1280
- rr=self.backend.conv2d(self.backend.bk_real(tim),self.ww_SmoothT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1281
- ii=self.backend.conv2d(self.backend.bk_imag(tim),self.ww_SmoothT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1282
- res=self.backend.bk_complex(rr,ii)
2460
+ ndata = 1
2461
+ for k in range(axis):
2462
+ ndata = ndata * ishape[k]
2463
+
2464
+ tim = self.backend.bk_reshape(
2465
+ self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2466
+ )
2467
+
2468
+ if self.backend.bk_is_complex(tim):
2469
+ rr = self.backend.conv2d(
2470
+ self.backend.bk_real(tim),
2471
+ self.ww_SmoothT[odata],
2472
+ strides=[1, 1, 1, 1],
2473
+ padding=self.padding,
2474
+ )
2475
+ ii = self.backend.conv2d(
2476
+ self.backend.bk_imag(tim),
2477
+ self.ww_SmoothT[odata],
2478
+ strides=[1, 1, 1, 1],
2479
+ padding=self.padding,
2480
+ )
2481
+ res = self.backend.bk_complex(rr, ii)
2482
+ else:
2483
+ res = self.backend.conv2d(
2484
+ tim,
2485
+ self.ww_SmoothT[odata],
2486
+ strides=[1, 1, 1, 1],
2487
+ padding=self.padding,
2488
+ )
2489
+
2490
+ if axis == 0:
2491
+ if len(ishape) == 2:
2492
+ return self.backend.bk_reshape(res, [res.shape[1], res.shape[2]])
2493
+ else:
2494
+ return self.backend.bk_reshape(
2495
+ res, [res.shape[1], res.shape[2]] + ishape[axis + 2 :]
2496
+ )
1283
2497
  else:
1284
- res=self.backend.conv2d(tim,self.ww_SmoothT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1285
-
1286
- if axis==0:
1287
- if len(ishape)==2:
1288
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2]])
2498
+ if len(ishape) == axis + 2:
2499
+ return self.backend.bk_reshape(
2500
+ res, ishape[0:axis] + [res.shape[1], res.shape[2]]
2501
+ )
1289
2502
  else:
1290
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2]]+ishape[axis+2:])
2503
+ return self.backend.bk_reshape(
2504
+ res,
2505
+ ishape[0:axis]
2506
+ + [res.shape[1], res.shape[2]]
2507
+ + ishape[axis + 2 :],
2508
+ )
2509
+
2510
+ return self.backend.bk_reshape(res, in_image.shape)
2511
+ elif self.use_1D:
2512
+
2513
+ ishape = list(in_image.shape)
2514
+ if len(ishape) < axis + 1:
2515
+ if not self.silent:
2516
+ print("Use of 1D scat with data that has less than 1D")
2517
+ return None
2518
+
2519
+ npix = ishape[axis]
2520
+ odata = 1
2521
+ if len(ishape) > axis + 1:
2522
+ for k in range(axis + 1, len(ishape)):
2523
+ odata = odata * ishape[k]
2524
+
2525
+ ndata = 1
2526
+ for k in range(axis):
2527
+ ndata = ndata * ishape[k]
2528
+
2529
+ tim = self.backend.bk_reshape(
2530
+ self.backend.bk_cast(in_image), [ndata, npix, odata]
2531
+ )
2532
+
2533
+ if self.backend.bk_is_complex(tim):
2534
+ rr = self.backend.conv1d(
2535
+ self.backend.bk_real(tim),
2536
+ self.ww_SmoothT[odata],
2537
+ strides=[1, 1, 1],
2538
+ padding=self.padding,
2539
+ )
2540
+ ii = self.backend.conv1d(
2541
+ self.backend.bk_imag(tim),
2542
+ self.ww_SmoothT[odata],
2543
+ strides=[1, 1, 1],
2544
+ padding=self.padding,
2545
+ )
2546
+ res = self.backend.bk_complex(rr, ii)
2547
+ else:
2548
+ res = self.backend.conv1d(
2549
+ tim, self.ww_SmoothT[odata], strides=[1, 1, 1], padding=self.padding
2550
+ )
2551
+
2552
+ if axis == 0:
2553
+ if len(ishape) == 1:
2554
+ return self.backend.bk_reshape(res, [res.shape[1]])
2555
+ else:
2556
+ return self.backend.bk_reshape(
2557
+ res, [res.shape[1]] + ishape[axis + 1 :]
2558
+ )
1291
2559
  else:
1292
- if len(ishape)==axis+2:
1293
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]])
2560
+ if len(ishape) == axis + 1:
2561
+ return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
1294
2562
  else:
1295
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]]+ishape[axis+2:])
1296
-
1297
- return self.backend.bk_reshape(res,[nout,nouty])
1298
-
2563
+ return self.backend.bk_reshape(
2564
+ res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2565
+ )
2566
+
2567
+ return self.backend.bk_reshape(res, in_image.shape)
2568
+
1299
2569
  else:
1300
- nside=int(np.sqrt(image.shape[axis]//12))
2570
+ nside = int(np.sqrt(image.shape[axis] // 12))
1301
2571
 
1302
2572
  if self.Idx_Neighbours[nside] is None:
1303
-
2573
+
1304
2574
  if self.InitWave is None:
1305
- wr,wi,ws,widx=self.init_index(nside)
2575
+ wr, wi, ws, widx = self.init_index(nside)
1306
2576
  else:
1307
- wr,wi,ws,widx=self.InitWave(self,nside)
1308
-
1309
- self.Idx_Neighbours[nside]=1
1310
- self.ww_Real[nside]=wr
1311
- self.ww_Imag[nside]=wi
1312
- self.w_smooth[nside]=ws
1313
-
1314
- l_w_smooth=self.w_smooth[nside]
1315
- ishape=list(image.shape)
1316
-
1317
- odata=1
1318
- for k in range(axis+1,len(ishape)):
1319
- odata=odata*ishape[k]
1320
-
1321
- if axis==0:
1322
- tim=self.backend.bk_reshape(image,[12*nside**2,odata])
1323
- if tim.dtype==self.all_cbk_type:
1324
- rr=self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_real(tim))
1325
- ri=self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_imag(tim))
1326
- res=self.backend.bk_complex(rr,ri)
2577
+ wr, wi, ws, widx = self.InitWave(self, nside)
2578
+
2579
+ self.Idx_Neighbours[nside] = 1
2580
+ self.ww_Real[nside] = wr
2581
+ self.ww_Imag[nside] = wi
2582
+ self.w_smooth[nside] = ws
2583
+
2584
+ l_w_smooth = self.w_smooth[nside]
2585
+ ishape = list(image.shape)
2586
+
2587
+ odata = 1
2588
+ for k in range(axis + 1, len(ishape)):
2589
+ odata = odata * ishape[k]
2590
+
2591
+ if axis == 0:
2592
+ tim = self.backend.bk_reshape(image, [12 * nside**2, odata])
2593
+ if tim.dtype == self.all_cbk_type:
2594
+ rr = self.backend.bk_sparse_dense_matmul(
2595
+ l_w_smooth, self.backend.bk_real(tim)
2596
+ )
2597
+ ri = self.backend.bk_sparse_dense_matmul(
2598
+ l_w_smooth, self.backend.bk_imag(tim)
2599
+ )
2600
+ res = self.backend.bk_complex(rr, ri)
1327
2601
  else:
1328
- res=self.backend.bk_sparse_dense_matmul(l_w_smooth,tim)
1329
- if len(ishape)==1:
1330
- return self.backend.bk_reshape(res,[12*nside**2])
2602
+ res = self.backend.bk_sparse_dense_matmul(l_w_smooth, tim)
2603
+ if len(ishape) == 1:
2604
+ return self.backend.bk_reshape(res, [12 * nside**2])
1331
2605
  else:
1332
- return self.backend.bk_reshape(res,[12*nside**2]+ishape[axis+1:])
1333
-
1334
- if axis>0:
1335
- ndata=ishape[0]
1336
- for k in range(1,axis):
1337
- ndata=ndata*ishape[k]
1338
- tim=self.backend.bk_reshape(image,[ndata,12*nside**2,odata])
1339
- if tim.dtype==self.all_cbk_type:
1340
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_real(tim[0])),[1,12*nside**2,odata])
1341
- ri=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_imag(tim[0])),[1,12*nside**2,odata])
1342
- res=self.backend.bk_complex(rr,ri)
2606
+ return self.backend.bk_reshape(
2607
+ res, [12 * nside**2] + ishape[axis + 1 :]
2608
+ )
2609
+
2610
+ if axis > 0:
2611
+ ndata = ishape[0]
2612
+ for k in range(1, axis):
2613
+ ndata = ndata * ishape[k]
2614
+ tim = self.backend.bk_reshape(image, [ndata, 12 * nside**2, odata])
2615
+ if tim.dtype == self.all_cbk_type:
2616
+ rr = self.backend.bk_reshape(
2617
+ self.backend.bk_sparse_dense_matmul(
2618
+ l_w_smooth, self.backend.bk_real(tim[0])
2619
+ ),
2620
+ [1, 12 * nside**2, odata],
2621
+ )
2622
+ ri = self.backend.bk_reshape(
2623
+ self.backend.bk_sparse_dense_matmul(
2624
+ l_w_smooth, self.backend.bk_imag(tim[0])
2625
+ ),
2626
+ [1, 12 * nside**2, odata],
2627
+ )
2628
+ res = self.backend.bk_complex(rr, ri)
1343
2629
  else:
1344
- res=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,tim[0]),[1,12*nside**2,odata])
1345
-
1346
- for k in range(1,ndata):
1347
- if tim.dtype==self.all_cbk_type:
1348
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_real(tim[k])),[1,12*nside**2,odata])
1349
- ri=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_imag(tim[k])),[1,12*nside**2,odata])
1350
- res=self.backend.bk_concat([res,self.backend.bk_complex(rr,ri)],0)
2630
+ res = self.backend.bk_reshape(
2631
+ self.backend.bk_sparse_dense_matmul(l_w_smooth, tim[0]),
2632
+ [1, 12 * nside**2, odata],
2633
+ )
2634
+
2635
+ for k in range(1, ndata):
2636
+ if tim.dtype == self.all_cbk_type:
2637
+ rr = self.backend.bk_reshape(
2638
+ self.backend.bk_sparse_dense_matmul(
2639
+ l_w_smooth, self.backend.bk_real(tim[k])
2640
+ ),
2641
+ [1, 12 * nside**2, odata],
2642
+ )
2643
+ ri = self.backend.bk_reshape(
2644
+ self.backend.bk_sparse_dense_matmul(
2645
+ l_w_smooth, self.backend.bk_imag(tim[k])
2646
+ ),
2647
+ [1, 12 * nside**2, odata],
2648
+ )
2649
+ res = self.backend.bk_concat(
2650
+ [res, self.backend.bk_complex(rr, ri)], 0
2651
+ )
1351
2652
  else:
1352
- res=self.backend.bk_concat([res,self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,tim[k]),[1,12*nside**2,odata])],0)
1353
-
1354
- if len(ishape)==axis+1:
1355
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2])
2653
+ res = self.backend.bk_concat(
2654
+ [
2655
+ res,
2656
+ self.backend.bk_reshape(
2657
+ self.backend.bk_sparse_dense_matmul(
2658
+ l_w_smooth, tim[k]
2659
+ ),
2660
+ [1, 12 * nside**2, odata],
2661
+ ),
2662
+ ],
2663
+ 0,
2664
+ )
2665
+
2666
+ if len(ishape) == axis + 1:
2667
+ return self.backend.bk_reshape(
2668
+ res, ishape[0:axis] + [12 * nside**2]
2669
+ )
1356
2670
  else:
1357
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2]+ishape[axis+1:])
1358
-
1359
-
1360
- return(res)
1361
-
2671
+ return self.backend.bk_reshape(
2672
+ res, ishape[0:axis] + [12 * nside**2] + ishape[axis + 1 :]
2673
+ )
2674
+
2675
+ return res
2676
+
1362
2677
  # ---------------------------------------------−---------
1363
2678
  def get_kernel_size(self):
1364
- return(self.KERNELSZ)
1365
-
2679
+ return self.KERNELSZ
2680
+
1366
2681
  # ---------------------------------------------−---------
1367
2682
  def get_nb_orient(self):
1368
- return(self.NORIENT)
1369
-
2683
+ return self.NORIENT
2684
+
1370
2685
  # ---------------------------------------------−---------
1371
- def get_ww(self,nside=1):
1372
- return(self.ww_Real[nside],self.ww_Imag[nside])
1373
-
2686
+ def get_ww(self, nside=1):
2687
+ return (self.ww_Real[nside], self.ww_Imag[nside])
2688
+
1374
2689
  # ---------------------------------------------−---------
1375
2690
  def plot_ww(self):
1376
- c,s=self.get_ww()
2691
+ c, s = self.get_ww()
1377
2692
  import matplotlib.pyplot as plt
1378
- plt.figure(figsize=(16,6))
1379
- npt=int(np.sqrt(c.shape[0]))
2693
+
2694
+ plt.figure(figsize=(16, 6))
2695
+ npt = int(np.sqrt(c.shape[0]))
1380
2696
  for i in range(c.shape[1]):
1381
- plt.subplot(2,c.shape[1],1+i)
1382
- plt.imshow(c[:,i].reshape(npt,npt),cmap='jet',vmin=-c.max(),vmax=c.max())
1383
- plt.subplot(2,c.shape[1],1+i+c.shape[1])
1384
- plt.imshow(s[:,i].reshape(npt,npt),cmap='jet',vmin=-c.max(),vmax=c.max())
2697
+ plt.subplot(2, c.shape[1], 1 + i)
2698
+ plt.imshow(
2699
+ c[:, i].reshape(npt, npt), cmap="jet", vmin=-c.max(), vmax=c.max()
2700
+ )
2701
+ plt.subplot(2, c.shape[1], 1 + i + c.shape[1])
2702
+ plt.imshow(
2703
+ s[:, i].reshape(npt, npt), cmap="jet", vmin=-c.max(), vmax=c.max()
2704
+ )
1385
2705
  sys.stdout.flush()
1386
2706
  plt.show()
1387
-
1388
-
1389
-
1390
-
1391
-