foscat 3.1.5__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -1,1186 +1,1468 @@
1
- import numpy as np
1
+ import os
2
+ import sys
3
+
2
4
  import healpy as hp
3
- import os, sys
4
- import foscat.backend as bk
5
+ import numpy as np
5
6
  from scipy.interpolate import griddata
6
7
 
8
+ import foscat.backend as bk
9
+
10
+ TMPFILE_VERSION = "V4_0"
7
11
 
8
- TMPFILE_VERSION='V4_0'
9
12
 
10
13
  class FoCUS:
11
- def __init__(self,
12
- NORIENT=4,
13
- LAMBDA=1.2,
14
- KERNELSZ=3,
15
- slope=1.0,
16
- all_type='float64',
17
- nstep_max=16,
18
- padding='SAME',
19
- gpupos=0,
20
- mask_thres=None,
21
- mask_norm=False,
22
- OSTEP=0,
23
- isMPI=False,
24
- TEMPLATE_PATH='data',
25
- BACKEND='tensorflow',
26
- use_2D=False,
27
- use_1D=False,
28
- return_data=False,
29
- JmaxDelta=0,
30
- DODIV=False,
31
- InitWave=None,
32
- silent=False,
33
- mpi_size=1,
34
- mpi_rank=0):
35
-
36
- self.__version__ = '3.1.5'
14
+ def __init__(
15
+ self,
16
+ NORIENT=4,
17
+ LAMBDA=1.2,
18
+ KERNELSZ=3,
19
+ slope=1.0,
20
+ all_type="float64",
21
+ nstep_max=16,
22
+ padding="SAME",
23
+ gpupos=0,
24
+ mask_thres=None,
25
+ mask_norm=False,
26
+ OSTEP=0,
27
+ isMPI=False,
28
+ TEMPLATE_PATH="data",
29
+ BACKEND="tensorflow",
30
+ use_2D=False,
31
+ use_1D=False,
32
+ return_data=False,
33
+ JmaxDelta=0,
34
+ DODIV=False,
35
+ InitWave=None,
36
+ silent=False,
37
+ mpi_size=1,
38
+ mpi_rank=0,
39
+ ):
40
+
41
+ self.__version__ = "3.2.0"
37
42
  # P00 coeff for normalization for scat_cov
38
- self.TMPFILE_VERSION=TMPFILE_VERSION
43
+ self.TMPFILE_VERSION = TMPFILE_VERSION
39
44
  self.P1_dic = None
40
45
  self.P2_dic = None
41
- self.isMPI=isMPI
46
+ self.isMPI = isMPI
42
47
  self.mask_thres = mask_thres
43
48
  self.mask_norm = mask_norm
44
- self.InitWave=InitWave
49
+ self.InitWave = InitWave
45
50
 
46
- self.mpi_size=mpi_size
47
- self.mpi_rank=mpi_rank
48
- self.return_data=return_data
49
- self.silent=silent
51
+ self.mpi_size = mpi_size
52
+ self.mpi_rank = mpi_rank
53
+ self.return_data = return_data
54
+ self.silent = silent
50
55
 
51
56
  if not self.silent:
52
- print('================================================')
53
- print(' START FOSCAT CONFIGURATION')
54
- print('================================================')
57
+ print("================================================")
58
+ print(" START FOSCAT CONFIGURATION")
59
+ print("================================================")
55
60
  sys.stdout.flush()
56
61
 
57
- self.TEMPLATE_PATH=TEMPLATE_PATH
58
- if os.path.exists(self.TEMPLATE_PATH)==False:
62
+ self.TEMPLATE_PATH = TEMPLATE_PATH
63
+ if not os.path.exists(self.TEMPLATE_PATH):
59
64
  if not self.silent:
60
- print('The directory %s to store temporary information for FoCUS does not exist: Try to create it'%(self.TEMPLATE_PATH))
65
+ print(
66
+ "The directory %s to store temporary information for FoCUS does not exist: Try to create it"
67
+ % (self.TEMPLATE_PATH)
68
+ )
61
69
  try:
62
- os.system('mkdir -p %s'%(self.TEMPLATE_PATH))
70
+ os.system("mkdir -p %s" % (self.TEMPLATE_PATH))
63
71
  if not self.silent:
64
- print('The directory %s is created')
72
+ print("The directory %s is created")
65
73
  except:
66
74
  if not self.silent:
67
- print('Impossible to create the directory %s'%(self.TEMPLATE_PATH))
75
+ print(
76
+ "Impossible to create the directory %s" % (self.TEMPLATE_PATH)
77
+ )
68
78
  return None
69
-
70
- self.number_of_loss=0
71
-
72
- self.history=np.zeros([10])
73
- self.nlog=0
74
- self.padding=padding
75
-
76
- if OSTEP!=0:
79
+
80
+ self.number_of_loss = 0
81
+
82
+ self.history = np.zeros([10])
83
+ self.nlog = 0
84
+ self.padding = padding
85
+
86
+ if OSTEP != 0:
77
87
  if not self.silent:
78
- print('OPTION option is deprecated after version 2.0.6. Please use Jmax option')
79
- JmaxDelta=OSTEP
88
+ print(
89
+ "OPTION option is deprecated after version 2.0.6. Please use Jmax option"
90
+ )
91
+ JmaxDelta = OSTEP
80
92
  else:
81
- OSTEP=JmaxDelta
82
-
83
- if JmaxDelta<-1:
93
+ OSTEP = JmaxDelta
94
+
95
+ if JmaxDelta < -1:
84
96
  if not self.silent:
85
- print('Warning : Jmax can not be smaller than -1')
97
+ print("Warning : Jmax can not be smaller than -1")
86
98
  return None
87
-
88
- self.OSTEP=JmaxDelta
89
- self.use_2D=use_2D
90
- self.use_1D=use_1D
91
-
99
+
100
+ self.OSTEP = JmaxDelta
101
+ self.use_2D = use_2D
102
+ self.use_1D = use_1D
103
+
92
104
  if isMPI:
93
105
  from mpi4py import MPI
94
106
 
95
- self.comm= MPI.COMM_WORLD
96
- if all_type=='float32':
97
- self.MPI_ALL_TYPE=MPI.FLOAT
107
+ self.comm = MPI.COMM_WORLD
108
+ if all_type == "float32":
109
+ self.MPI_ALL_TYPE = MPI.FLOAT
98
110
  else:
99
- self.MPI_ALL_TYPE=MPI.DOUBLE
111
+ self.MPI_ALL_TYPE = MPI.DOUBLE
100
112
  else:
101
- self.MPI_ALL_TYPE=None
102
-
103
- self.all_type=all_type
104
- self.BACKEND=BACKEND
105
- self.backend=bk.foscat_backend(BACKEND,
106
- all_type=all_type,
107
- mpi_rank=mpi_rank,
108
- gpupos=gpupos,
109
- silent=self.silent)
110
-
111
- self.all_bk_type=self.backend.all_bk_type
112
- self.all_cbk_type=self.backend.all_cbk_type
113
- self.gpulist=self.backend.gpulist
114
- self.ngpu=self.backend.ngpu
115
- self.rank=mpi_rank
116
-
117
- self.gpupos=(gpupos+mpi_rank)%self.backend.ngpu
113
+ self.MPI_ALL_TYPE = None
114
+
115
+ self.all_type = all_type
116
+ self.BACKEND = BACKEND
117
+ self.backend = bk.foscat_backend(
118
+ BACKEND,
119
+ all_type=all_type,
120
+ mpi_rank=mpi_rank,
121
+ gpupos=gpupos,
122
+ silent=self.silent,
123
+ )
124
+
125
+ self.all_bk_type = self.backend.all_bk_type
126
+ self.all_cbk_type = self.backend.all_cbk_type
127
+ self.gpulist = self.backend.gpulist
128
+ self.ngpu = self.backend.ngpu
129
+ self.rank = mpi_rank
130
+
131
+ self.gpupos = (gpupos + mpi_rank) % self.backend.ngpu
118
132
 
119
133
  if not self.silent:
120
- print('============================================================')
121
- print('== ==')
122
- print('== ==')
123
- print('== RUN ON GPU Rank %d : %s =='%(mpi_rank,self.gpulist[self.gpupos%self.ngpu]))
124
- print('== ==')
125
- print('== ==')
126
- print('============================================================')
134
+ print("============================================================")
135
+ print("== ==")
136
+ print("== ==")
137
+ print(
138
+ "== RUN ON GPU Rank %d : %s =="
139
+ % (mpi_rank, self.gpulist[self.gpupos % self.ngpu])
140
+ )
141
+ print("== ==")
142
+ print("== ==")
143
+ print("============================================================")
127
144
  sys.stdout.flush()
128
145
 
129
- l_NORIENT=NORIENT
146
+ l_NORIENT = NORIENT
130
147
  if DODIV:
131
- l_NORIENT=NORIENT+2
132
-
133
- self.NORIENT=l_NORIENT
134
- self.LAMBDA=LAMBDA
135
- self.slope=slope
136
-
137
- self.R_off=(KERNELSZ-1)//2
138
- if (self.R_off//2)*2<self.R_off:
139
- self.R_off+=1
140
-
141
- self.ww_Real = {}
142
- self.ww_Imag = {}
143
- self.ww_CNN_Transpose = {}
144
- self.ww_CNN = {}
145
- self.X_CNN = {}
146
- self.Y_CNN = {}
147
- self.Z_CNN = {}
148
-
149
- wwc=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
150
- wws=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
151
-
152
- x=np.repeat(np.arange(KERNELSZ)-KERNELSZ//2,KERNELSZ).reshape(KERNELSZ,KERNELSZ)
153
- y=x.T
154
-
155
- if NORIENT==1:
156
- xx=(3/float(KERNELSZ))*LAMBDA*x
157
- yy=(3/float(KERNELSZ))*LAMBDA*y
158
-
159
- if KERNELSZ==5:
160
- #w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
161
- w_smooth=np.exp(-(xx**2+yy**2))
162
- tmp=np.exp(-2*(xx**2+yy**2))-0.25*np.exp(-0.5*(xx**2+yy**2))
148
+ l_NORIENT = NORIENT + 2
149
+
150
+ self.NORIENT = l_NORIENT
151
+ self.LAMBDA = LAMBDA
152
+ self.slope = slope
153
+
154
+ self.R_off = (KERNELSZ - 1) // 2
155
+ if (self.R_off // 2) * 2 < self.R_off:
156
+ self.R_off += 1
157
+
158
+ self.ww_Real = {}
159
+ self.ww_Imag = {}
160
+ self.ww_CNN_Transpose = {}
161
+ self.ww_CNN = {}
162
+ self.X_CNN = {}
163
+ self.Y_CNN = {}
164
+ self.Z_CNN = {}
165
+
166
+ wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
167
+ wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
168
+
169
+ x = np.repeat(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape(
170
+ KERNELSZ, KERNELSZ
171
+ )
172
+ y = x.T
173
+
174
+ if NORIENT == 1:
175
+ xx = (3 / float(KERNELSZ)) * LAMBDA * x
176
+ yy = (3 / float(KERNELSZ)) * LAMBDA * y
177
+
178
+ if KERNELSZ == 5:
179
+ # w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
180
+ w_smooth = np.exp(-(xx**2 + yy**2))
181
+ tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp(
182
+ -0.5 * (xx**2 + yy**2)
183
+ )
163
184
  else:
164
- w_smooth=np.exp(-0.5*(xx**2+yy**2))
165
- tmp=np.exp(-2*(xx**2+yy**2))-0.25*np.exp(-0.5*(xx**2+yy**2))
166
-
167
- wwc[:,0]=tmp.flatten()-tmp.mean()
168
- tmp=0*w_smooth
169
- wws[:,0]=tmp.flatten()
170
- sigma=np.sqrt((wwc[:,0]**2).mean())
171
- wwc[:,0]/=sigma
172
- wws[:,0]/=sigma
173
-
174
- w_smooth=w_smooth.flatten()
185
+ w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
186
+ tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp(
187
+ -0.5 * (xx**2 + yy**2)
188
+ )
189
+
190
+ wwc[:, 0] = tmp.flatten() - tmp.mean()
191
+ tmp = 0 * w_smooth
192
+ wws[:, 0] = tmp.flatten()
193
+ sigma = np.sqrt((wwc[:, 0] ** 2).mean())
194
+ wwc[:, 0] /= sigma
195
+ wws[:, 0] /= sigma
196
+
197
+ w_smooth = w_smooth.flatten()
175
198
  else:
176
199
  for i in range(NORIENT):
177
- a=i/float(NORIENT)*np.pi
178
- xx=(3/float(KERNELSZ))*LAMBDA*(x*np.cos(a)+y*np.sin(a))
179
- yy=(3/float(KERNELSZ))*LAMBDA*(x*np.sin(a)-y*np.cos(a))
200
+ a = i / float(NORIENT) * np.pi
201
+ xx = (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
202
+ yy = (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
180
203
 
181
- if KERNELSZ==5:
182
- #w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
183
- w_smooth=np.exp(-(xx**2+yy**2))
204
+ if KERNELSZ == 5:
205
+ # w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
206
+ w_smooth = np.exp(-(xx**2 + yy**2))
184
207
  else:
185
- w_smooth=np.exp(-0.5*(xx**2+yy**2))
186
- tmp1=np.cos(yy*np.pi)*w_smooth
187
- tmp2=np.sin(yy*np.pi)*w_smooth
188
-
189
- wwc[:,i]=tmp1.flatten()-tmp1.mean()
190
- wws[:,i]=tmp2.flatten()-tmp2.mean()
191
- sigma=np.sqrt((wwc[:,i]**2).mean())
192
- wwc[:,i]/=sigma
193
- wws[:,i]/=sigma
194
-
195
- if DODIV and i==0:
196
- r=(xx**2+yy**2)
197
- theta=np.arctan2(yy,xx)
198
- theta[KERNELSZ//2,KERNELSZ//2]=0.0
199
- tmp1=r*np.cos(2*theta)*w_smooth
200
- tmp2=r*np.sin(2*theta)*w_smooth
201
-
202
- wwc[:,NORIENT]=tmp1.flatten()-tmp1.mean()
203
- wws[:,NORIENT]=tmp2.flatten()-tmp2.mean()
204
- sigma=np.sqrt((wwc[:,NORIENT]**2).mean())
205
-
206
- wwc[:,NORIENT]/=sigma
207
- wws[:,NORIENT]/=sigma
208
- tmp1=r*np.cos(2*theta+np.pi)
209
- tmp2=r*np.sin(2*theta+np.pi)
210
-
211
- wwc[:,NORIENT+1]=tmp1.flatten()-tmp1.mean()
212
- wws[:,NORIENT+1]=tmp2.flatten()-tmp2.mean()
213
- sigma=np.sqrt((wwc[:,NORIENT+1]**2).mean())
214
- wwc[:,NORIENT+1]/=sigma
215
- wws[:,NORIENT+1]/=sigma
216
-
217
-
218
- w_smooth=w_smooth.flatten()
208
+ w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
209
+ tmp1 = np.cos(yy * np.pi) * w_smooth
210
+ tmp2 = np.sin(yy * np.pi) * w_smooth
211
+
212
+ wwc[:, i] = tmp1.flatten() - tmp1.mean()
213
+ wws[:, i] = tmp2.flatten() - tmp2.mean()
214
+ sigma = np.sqrt((wwc[:, i] ** 2).mean())
215
+ wwc[:, i] /= sigma
216
+ wws[:, i] /= sigma
217
+
218
+ if DODIV and i == 0:
219
+ r = xx**2 + yy**2
220
+ theta = np.arctan2(yy, xx)
221
+ theta[KERNELSZ // 2, KERNELSZ // 2] = 0.0
222
+ tmp1 = r * np.cos(2 * theta) * w_smooth
223
+ tmp2 = r * np.sin(2 * theta) * w_smooth
224
+
225
+ wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
226
+ wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
227
+ sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
228
+
229
+ wwc[:, NORIENT] /= sigma
230
+ wws[:, NORIENT] /= sigma
231
+ tmp1 = r * np.cos(2 * theta + np.pi)
232
+ tmp2 = r * np.sin(2 * theta + np.pi)
233
+
234
+ wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
235
+ wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
236
+ sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
237
+ wwc[:, NORIENT + 1] /= sigma
238
+ wws[:, NORIENT + 1] /= sigma
239
+
240
+ w_smooth = w_smooth.flatten()
219
241
  if self.use_1D:
220
- KERNELSZ=5
221
-
222
- self.KERNELSZ=KERNELSZ
242
+ KERNELSZ = 5
243
+
244
+ self.KERNELSZ = KERNELSZ
245
+
246
+ self.Idx_Neighbours = {}
223
247
 
224
- self.Idx_Neighbours={}
225
-
226
248
  if not self.use_2D and not self.use_1D:
227
249
  self.w_smooth = {}
228
250
  for i in range(nstep_max):
229
- lout=(2**i)
230
- self.ww_Real[lout]=None
251
+ lout = 2**i
252
+ self.ww_Real[lout] = None
231
253
 
232
- for i in range(1,6):
233
- lout=(2**i)
254
+ for i in range(1, 6):
255
+ lout = 2**i
234
256
  if not self.silent:
235
- print('Init Wave ',lout)
236
-
257
+ print("Init Wave ", lout)
258
+
237
259
  if self.InitWave is None:
238
- wr,wi,ws,widx=self.init_index(lout)
260
+ wr, wi, ws, widx = self.init_index(lout)
239
261
  else:
240
- wr,wi,ws,widx=self.InitWave(self,lout)
241
-
242
- self.Idx_Neighbours[lout]=1 #self.backend.constant(widx)
243
- self.ww_Real[lout]=wr
244
- self.ww_Imag[lout]=wi
245
- self.w_smooth[lout]=ws
246
- elif self.use_1D==True:
247
- self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
248
- self.ww_RealT={}
249
- self.ww_ImagT={}
250
- self.ww_SmoothT={}
251
- if KERNELSZ==5:
252
- xx=np.arange(5)-2
253
- w=np.exp(-0.25*(xx)**2)
254
- c=w*np.cos((xx)*np.pi/2)
255
- s=w*np.sin((xx)*np.pi/2)
256
-
257
- w=w/np.sum(w)
258
- c=c-np.mean(c)
259
- s=s-np.mean(s)
260
- r=np.sum(np.sqrt(c*c+s*s))
261
- c=c/r
262
- s=s/r
263
- self.ww_RealT[1]=self.backend.constant(np.array(c).reshape(xx.shape[0],1,1))
264
- self.ww_ImagT[1]=self.backend.constant(np.array(s).reshape(xx.shape[0],1,1))
265
- self.ww_SmoothT[1] = self.backend.constant(np.array(w).reshape(xx.shape[0],1,1))
266
-
267
- else:
268
- self.w_smooth=slope*(w_smooth/w_smooth.sum()).astype(self.all_type)
269
- self.ww_RealT={}
270
- self.ww_ImagT={}
271
- self.ww_SmoothT={}
262
+ wr, wi, ws, widx = self.InitWave(self, lout)
272
263
 
273
- self.ww_SmoothT[1] = self.backend.constant(self.w_smooth.reshape(KERNELSZ,KERNELSZ,1,1))
274
- www=np.zeros([KERNELSZ,KERNELSZ,NORIENT,NORIENT],dtype=self.all_type)
264
+ self.Idx_Neighbours[lout] = 1 # self.backend.constant(widx)
265
+ self.ww_Real[lout] = wr
266
+ self.ww_Imag[lout] = wi
267
+ self.w_smooth[lout] = ws
268
+ elif self.use_1D:
269
+ self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
270
+ self.ww_RealT = {}
271
+ self.ww_ImagT = {}
272
+ self.ww_SmoothT = {}
273
+ if KERNELSZ == 5:
274
+ xx = np.arange(5) - 2
275
+ w = np.exp(-0.25 * (xx) ** 2)
276
+ c = w * np.cos((xx) * np.pi / 2)
277
+ s = w * np.sin((xx) * np.pi / 2)
278
+
279
+ w = w / np.sum(w)
280
+ c = c - np.mean(c)
281
+ s = s - np.mean(s)
282
+ r = np.sum(np.sqrt(c * c + s * s))
283
+ c = c / r
284
+ s = s / r
285
+ self.ww_RealT[1] = self.backend.constant(
286
+ np.array(c).reshape(xx.shape[0], 1, 1)
287
+ )
288
+ self.ww_ImagT[1] = self.backend.constant(
289
+ np.array(s).reshape(xx.shape[0], 1, 1)
290
+ )
291
+ self.ww_SmoothT[1] = self.backend.constant(
292
+ np.array(w).reshape(xx.shape[0], 1, 1)
293
+ )
294
+
295
+ else:
296
+ self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
297
+ self.ww_RealT = {}
298
+ self.ww_ImagT = {}
299
+ self.ww_SmoothT = {}
300
+
301
+ self.ww_SmoothT[1] = self.backend.constant(
302
+ self.w_smooth.reshape(KERNELSZ, KERNELSZ, 1, 1)
303
+ )
304
+ www = np.zeros([KERNELSZ, KERNELSZ, NORIENT, NORIENT], dtype=self.all_type)
275
305
  for k in range(NORIENT):
276
- www[:,:,k,k]=self.w_smooth.reshape(KERNELSZ,KERNELSZ)
277
- self.ww_SmoothT[NORIENT] = self.backend.constant(www.reshape(KERNELSZ,KERNELSZ,NORIENT,NORIENT))
278
- self.ww_RealT[1]=self.backend.constant(self.backend.bk_reshape(wwc.astype(self.all_type),[KERNELSZ,KERNELSZ,1,NORIENT]))
279
- self.ww_ImagT[1]=self.backend.constant(self.backend.bk_reshape(wws.astype(self.all_type),[KERNELSZ,KERNELSZ,1,NORIENT]))
306
+ www[:, :, k, k] = self.w_smooth.reshape(KERNELSZ, KERNELSZ)
307
+ self.ww_SmoothT[NORIENT] = self.backend.constant(
308
+ www.reshape(KERNELSZ, KERNELSZ, NORIENT, NORIENT)
309
+ )
310
+ self.ww_RealT[1] = self.backend.constant(
311
+ self.backend.bk_reshape(
312
+ wwc.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
313
+ )
314
+ )
315
+ self.ww_ImagT[1] = self.backend.constant(
316
+ self.backend.bk_reshape(
317
+ wws.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
318
+ )
319
+ )
320
+
280
321
  def doorientw(x):
281
- y=np.zeros([KERNELSZ,KERNELSZ,NORIENT,NORIENT*NORIENT],dtype=self.all_type)
322
+ y = np.zeros(
323
+ [KERNELSZ, KERNELSZ, NORIENT, NORIENT * NORIENT],
324
+ dtype=self.all_type,
325
+ )
282
326
  for k in range(NORIENT):
283
- y[:,:,k,k*NORIENT:k*NORIENT+NORIENT]=x.reshape(KERNELSZ,KERNELSZ,NORIENT)
327
+ y[:, :, k, k * NORIENT : k * NORIENT + NORIENT] = x.reshape(
328
+ KERNELSZ, KERNELSZ, NORIENT
329
+ )
284
330
  return y
285
- self.ww_RealT[NORIENT]=self.backend.constant(doorientw(wwc.astype(self.all_type)))
286
- self.ww_ImagT[NORIENT]=self.backend.constant(doorientw(wws.astype(self.all_type)))
287
- self.pix_interp_val={}
288
- self.weight_interp_val={}
289
- self.ring2nest={}
290
- self.nest2R={}
291
- self.nest2R1={}
292
- self.nest2R2={}
293
- self.nest2R3={}
294
- self.nest2R4={}
295
- self.inv_nest2R={}
296
- self.remove_border={}
297
-
298
- self.ampnorm={}
299
-
331
+
332
+ self.ww_RealT[NORIENT] = self.backend.constant(
333
+ doorientw(wwc.astype(self.all_type))
334
+ )
335
+ self.ww_ImagT[NORIENT] = self.backend.constant(
336
+ doorientw(wws.astype(self.all_type))
337
+ )
338
+ self.pix_interp_val = {}
339
+ self.weight_interp_val = {}
340
+ self.ring2nest = {}
341
+ self.nest2R = {}
342
+ self.nest2R1 = {}
343
+ self.nest2R2 = {}
344
+ self.nest2R3 = {}
345
+ self.nest2R4 = {}
346
+ self.inv_nest2R = {}
347
+ self.remove_border = {}
348
+
349
+ self.ampnorm = {}
350
+
300
351
  for i in range(nstep_max):
301
- lout=(2**i)
302
- self.pix_interp_val[lout]={}
303
- self.weight_interp_val[lout]={}
352
+ lout = 2**i
353
+ self.pix_interp_val[lout] = {}
354
+ self.weight_interp_val[lout] = {}
304
355
  for j in range(nstep_max):
305
- lout2=(2**j)
306
- self.pix_interp_val[lout][lout2]=None
307
- self.weight_interp_val[lout][lout2]=None
308
- self.ring2nest[lout]=None
309
- self.Idx_Neighbours[lout]=None
310
- self.nest2R[lout]=None
311
- self.nest2R1[lout]=None
312
- self.nest2R2[lout]=None
313
- self.nest2R3[lout]=None
314
- self.nest2R4[lout]=None
315
- self.inv_nest2R[lout]=None
316
- self.remove_border[lout]=None
317
- self.ww_CNN_Transpose[lout]=None
318
- self.ww_CNN[lout]=None
319
- self.X_CNN[lout]=None
320
- self.Y_CNN[lout]=None
321
- self.Z_CNN[lout]=None
322
-
323
- self.loss={}
356
+ lout2 = 2**j
357
+ self.pix_interp_val[lout][lout2] = None
358
+ self.weight_interp_val[lout][lout2] = None
359
+ self.ring2nest[lout] = None
360
+ self.Idx_Neighbours[lout] = None
361
+ self.nest2R[lout] = None
362
+ self.nest2R1[lout] = None
363
+ self.nest2R2[lout] = None
364
+ self.nest2R3[lout] = None
365
+ self.nest2R4[lout] = None
366
+ self.inv_nest2R[lout] = None
367
+ self.remove_border[lout] = None
368
+ self.ww_CNN_Transpose[lout] = None
369
+ self.ww_CNN[lout] = None
370
+ self.X_CNN[lout] = None
371
+ self.Y_CNN[lout] = None
372
+ self.Z_CNN[lout] = None
373
+
374
+ self.loss = {}
324
375
 
325
376
  def get_type(self):
326
377
  return self.all_type
327
378
 
328
379
  def get_mpi_type(self):
329
380
  return self.MPI_ALL_TYPE
330
-
381
+
331
382
  # ---------------------------------------------−---------
332
383
  # -- COMPUTE 3X3 INDEX FOR HEALPIX WORK --
333
384
  # ---------------------------------------------−---------
334
- def conv_to_FoCUS(self,x,axis=0):
335
- if self.use_2D and isinstance(x,np.ndarray):
336
- return(self.to_R(x,axis,chans=self.chans))
385
+ def conv_to_FoCUS(self, x, axis=0):
386
+ if self.use_2D and isinstance(x, np.ndarray):
387
+ return self.to_R(x, axis, chans=self.chans)
337
388
  return x
338
389
 
339
- def diffang(self,a,b):
340
- return np.arctan2(np.sin(a)-np.sin(b),np.cos(a)-np.cos(b))
341
-
342
- def corr_idx_wXX(self,x,y):
343
- idx=np.where(x==-1)[0]
344
- res=x
345
- res[idx]=y[idx]
346
- return(res)
390
+ def diffang(self, a, b):
391
+ return np.arctan2(np.sin(a) - np.sin(b), np.cos(a) - np.cos(b))
392
+
393
+ def corr_idx_wXX(self, x, y):
394
+ idx = np.where(x == -1)[0]
395
+ res = x
396
+ res[idx] = y[idx]
397
+ return res
347
398
 
348
399
  # ---------------------------------------------−---------
349
400
  # make the CNN working : index reporjection of the kernel on healpix
350
-
351
- def calc_indices_convol(self,nside,kernel,rotation=None):
352
- to,po=hp.pix2ang(nside,np.arange(12*nside*nside),nest=True)
353
- x,y,z=hp.pix2vec(nside,np.arange(12*nside*nside),nest=True)
354
-
355
- idx=np.argsort((x-1.0)**2+y**2+z**2)[0:kernel]
356
- x0,y0,z0=hp.pix2vec(nside,idx[0],nest=True)
357
- t0,p0=hp.pix2ang(nside,idx[0],nest=True)
358
-
359
- idx=np.argsort((x-x0)**2+(y-y0)**2+(z-z0)**2)[0:kernel]
360
- im=np.ones([12*nside**2])*-1
361
- im[idx]=np.arange(len(idx))
362
-
363
- xc,yc,zc=hp.pix2vec(nside,idx,nest=True)
364
-
365
- xc-=x0
366
- yc-=y0
367
- zc-=z0
368
-
369
- vec=np.concatenate([np.expand_dims(x,-1),
370
- np.expand_dims(y,-1),
371
- np.expand_dims(z,-1)],1)
372
-
373
- indices=np.zeros([12*nside**2*250,2],dtype='int')
374
- weights=np.zeros([12*nside**2*250])
375
- nn=0
376
- for k in range(12*nside*nside):
377
- if k%(nside*nside)==nside*nside-1:
378
- print('Nside=%d KenelSZ=%d %.2f%%'%(nside,kernel,k/(12*nside**2)*100))
379
- if nside<4:
380
- idx2=np.arange(12*nside**2)
401
+
402
+ def calc_indices_convol(self, nside, kernel, rotation=None):
403
+ to, po = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
404
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside * nside), nest=True)
405
+
406
+ idx = np.argsort((x - 1.0) ** 2 + y**2 + z**2)[0:kernel]
407
+ x0, y0, z0 = hp.pix2vec(nside, idx[0], nest=True)
408
+ t0, p0 = hp.pix2ang(nside, idx[0], nest=True)
409
+
410
+ idx = np.argsort((x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2)[0:kernel]
411
+ im = np.ones([12 * nside**2]) * -1
412
+ im[idx] = np.arange(len(idx))
413
+
414
+ xc, yc, zc = hp.pix2vec(nside, idx, nest=True)
415
+
416
+ xc -= x0
417
+ yc -= y0
418
+ zc -= z0
419
+
420
+ vec = np.concatenate(
421
+ [np.expand_dims(x, -1), np.expand_dims(y, -1), np.expand_dims(z, -1)], 1
422
+ )
423
+
424
+ indices = np.zeros([12 * nside**2 * 250, 2], dtype="int")
425
+ weights = np.zeros([12 * nside**2 * 250])
426
+ nn = 0
427
+ for k in range(12 * nside * nside):
428
+ if k % (nside * nside) == nside * nside - 1:
429
+ print(
430
+ "Nside=%d KenelSZ=%d %.2f%%"
431
+ % (nside, kernel, k / (12 * nside**2) * 100)
432
+ )
433
+ if nside < 4:
434
+ idx2 = np.arange(12 * nside**2)
381
435
  else:
382
- idx2=hp.query_disc(nside, vec[k], np.pi/nside, inclusive=True,nest=True)
383
- t2,p2=hp.pix2ang(nside,idx2,nest=True)
436
+ idx2 = hp.query_disc(
437
+ nside, vec[k], np.pi / nside, inclusive=True, nest=True
438
+ )
439
+ t2, p2 = hp.pix2ang(nside, idx2, nest=True)
384
440
  if rotation is None:
385
- rot=[po[k]/np.pi*180.0,(t0-to[k])/np.pi*180.0]
441
+ rot = [po[k] / np.pi * 180.0, (t0 - to[k]) / np.pi * 180.0]
386
442
  else:
387
- rot=[po[k]/np.pi*180.0,(t0-to[k])/np.pi*180.0,rotation[k]]
388
-
389
- r=hp.Rotator(rot=rot)
390
- t2,p2=r(t2,p2)
391
-
392
- ii,ww=hp.get_interp_weights(nside,t2,p2,nest=True)
393
-
394
- ii=im[ii]
395
-
396
- for l in range(4):
397
- iii=np.where(ii[l]!=-1)[0]
398
- if len(iii)>0:
399
- indices[nn:nn+len(iii),1]=idx2[iii]
400
- indices[nn:nn+len(iii),0]=k*kernel+ii[l,iii]
401
- weights[nn:nn+len(iii)]=ww[l,iii]
402
- nn+=len(iii)
403
-
404
- indices=indices[0:nn]
405
- weights=weights[0:nn]
406
- if k%(nside*nside)==nside*nside-1:
407
- print('Nside=%d KenelSZ=%d Total Number of value=%d Ratio of the matrix %.2g%%'%(nside,
408
- kernel,
409
- nn,
410
- 100*nn/(kernel*12*nside**2*12*nside**2)))
411
- return indices,weights,xc,yc,zc
443
+ rot = [po[k] / np.pi * 180.0, (t0 - to[k]) / np.pi * 180.0, rotation[k]]
444
+
445
+ r = hp.Rotator(rot=rot)
446
+ t2, p2 = r(t2, p2)
447
+
448
+ ii, ww = hp.get_interp_weights(nside, t2, p2, nest=True)
449
+
450
+ ii = im[ii]
451
+
452
+ for l_rotation in range(4):
453
+ iii = np.where(ii[l_rotation] != -1)[0]
454
+ if len(iii) > 0:
455
+ indices[nn : nn + len(iii), 1] = idx2[iii]
456
+ indices[nn : nn + len(iii), 0] = k * kernel + ii[l_rotation, iii]
457
+ weights[nn : nn + len(iii)] = ww[l_rotation, iii]
458
+ nn += len(iii)
459
+
460
+ indices = indices[0:nn]
461
+ weights = weights[0:nn]
462
+ if k % (nside * nside) == nside * nside - 1:
463
+ print(
464
+ "Nside=%d KenelSZ=%d Total Number of value=%d Ratio of the matrix %.2g%%"
465
+ % (
466
+ nside,
467
+ kernel,
468
+ nn,
469
+ 100 * nn / (kernel * 12 * nside**2 * 12 * nside**2),
470
+ )
471
+ )
472
+ return indices, weights, xc, yc, zc
412
473
 
413
474
  # ---------------------------------------------−---------
414
- def calc_orientation(self,im): # im is [Ndata,12*Nside**2]
415
- nside=int(np.sqrt(im.shape[1]//12))
416
- l_kernel=self.KERNELSZ*self.KERNELSZ
417
- norient=32
418
- w=np.zeros([l_kernel,1,2*norient])
419
- ca=np.cos(np.arange(norient)/norient*np.pi)
420
- sa=np.sin(np.arange(norient)/norient*np.pi)
421
- stat=np.zeros([12*nside**2,norient])
422
-
475
+ def calc_orientation(self, im): # im is [Ndata,12*Nside**2]
476
+ nside = int(np.sqrt(im.shape[1] // 12))
477
+ l_kernel = self.KERNELSZ * self.KERNELSZ
478
+ norient = 32
479
+ w = np.zeros([l_kernel, 1, 2 * norient])
480
+ ca = np.cos(np.arange(norient) / norient * np.pi)
481
+ sa = np.sin(np.arange(norient) / norient * np.pi)
482
+ stat = np.zeros([12 * nside**2, norient])
483
+
423
484
  if self.ww_CNN[nside] is None:
424
- self.init_CNN_index(nside,transpose=False)
425
-
426
- y=self.Y_CNN[nside]
427
- z=self.Z_CNN[nside]
428
-
485
+ self.init_CNN_index(nside, transpose=False)
486
+
487
+ y = self.Y_CNN[nside]
488
+ z = self.Z_CNN[nside]
489
+
429
490
  for k in range(norient):
430
- w[:,0,k]=(np.exp(-0.5*nside**2*((y)**2+(z)**2))*np.cos(nside*(y*ca[k]+z*sa[k])*np.pi/2))
431
- w[:,0,k+norient]=(np.exp(-0.5*nside**2*((y)**2+(z)**2))*np.sin(nside*(y*ca[k]+z*sa[k])*np.pi/2))
432
- w[:,0,k]=w[:,0,k]-np.mean(w[:,0,k])
433
- w[:,0,k+norient]=w[:,0,k]-np.mean(w[:,0,k+norient])
491
+ w[:, 0, k] = np.exp(-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)) * np.cos(
492
+ nside * (y * ca[k] + z * sa[k]) * np.pi / 2
493
+ )
494
+ w[:, 0, k + norient] = np.exp(
495
+ -0.5 * nside**2 * ((y) ** 2 + (z) ** 2)
496
+ ) * np.sin(nside * (y * ca[k] + z * sa[k]) * np.pi / 2)
497
+ w[:, 0, k] = w[:, 0, k] - np.mean(w[:, 0, k])
498
+ w[:, 0, k + norient] = w[:, 0, k] - np.mean(w[:, 0, k + norient])
434
499
 
435
500
  for k in range(im.shape[0]):
436
- tmp=im[k].reshape(12*nside**2,1)
437
- im2=self.healpix_layer(tmp,w)
438
- stat=stat+im2[:,0:norient]**2+im2[:,norient:]**2
439
-
440
- rotation=(np.argmax(stat,1)).astype('float')/32.*180.0
441
-
442
- indices,weights,x,y,z=self.calc_indices_convol(nside,9,rotation=rotation)
443
-
444
- return indices,weights
445
-
446
- def init_CNN_index(self,nside,transpose=False):
447
- l_kernel=int(self.KERNELSZ*self.KERNELSZ)
501
+ tmp = im[k].reshape(12 * nside**2, 1)
502
+ im2 = self.healpix_layer(tmp, w)
503
+ stat = stat + im2[:, 0:norient] ** 2 + im2[:, norient:] ** 2
504
+
505
+ rotation = (np.argmax(stat, 1)).astype("float") / 32.0 * 180.0
506
+
507
+ indices, weights, x, y, z = self.calc_indices_convol(
508
+ nside, 9, rotation=rotation
509
+ )
510
+
511
+ return indices, weights
512
+
513
+ def init_CNN_index(self, nside, transpose=False):
514
+ l_kernel = int(self.KERNELSZ * self.KERNELSZ)
448
515
  try:
449
- indices=np.load('%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
450
- weights=np.load('%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
451
- xc=np.load('%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
452
- yc=np.load('%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
453
- zc=np.load('%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
516
+ indices = np.load(
517
+ "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
518
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
519
+ )
520
+ weights = np.load(
521
+ "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
522
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
523
+ )
524
+ xc = np.load(
525
+ "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
526
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
527
+ )
528
+ yc = np.load(
529
+ "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
530
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
531
+ )
532
+ zc = np.load(
533
+ "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
534
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
535
+ )
454
536
  except:
455
- indices,weights,xc,yc,zc=self.calc_indices_convol(nside,l_kernel)
456
- np.save('%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside),indices)
457
- np.save('%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside),weights)
458
- np.save('%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside),xc)
459
- np.save('%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside),yc)
460
- np.save('%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside),zc)
537
+ indices, weights, xc, yc, zc = self.calc_indices_convol(nside, l_kernel)
538
+ np.save(
539
+ "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
540
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
541
+ indices,
542
+ )
543
+ np.save(
544
+ "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
545
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
546
+ weights,
547
+ )
548
+ np.save(
549
+ "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
550
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
551
+ xc,
552
+ )
553
+ np.save(
554
+ "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
555
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
556
+ yc,
557
+ )
558
+ np.save(
559
+ "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
560
+ % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
561
+ zc,
562
+ )
461
563
  if not self.silent:
462
- print('Write %s/FOSCAT_%s_W%d_%d_%d_CNNV2.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel,self.NORIENT,nside))
463
-
464
- self.X_CNN[nside]=xc
465
- self.Y_CNN[nside]=yc
466
- self.Z_CNN[nside]=zc
467
- self.ww_CNN[nside]=self.backend.bk_SparseTensor(indices,
468
- weights,[12*nside*nside*l_kernel,
469
- 12*nside*nside])
470
-
564
+ print(
565
+ "Write %s/FOSCAT_%s_W%d_%d_%d_CNNV2.npy"
566
+ % (
567
+ self.TEMPLATE_PATH,
568
+ TMPFILE_VERSION,
569
+ l_kernel,
570
+ self.NORIENT,
571
+ nside,
572
+ )
573
+ )
574
+
575
+ self.X_CNN[nside] = xc
576
+ self.Y_CNN[nside] = yc
577
+ self.Z_CNN[nside] = zc
578
+ self.ww_CNN[nside] = self.backend.bk_SparseTensor(
579
+ indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
580
+ )
581
+
471
582
  # ---------------------------------------------−---------
472
- def healpix_layer_coord(self,im,axis=0):
473
- nside=int(np.sqrt(im.shape[axis]//12))
474
- l_kernel=self.KERNELSZ*self.KERNELSZ
583
+ def healpix_layer_coord(self, im, axis=0):
584
+ nside = int(np.sqrt(im.shape[axis] // 12))
475
585
  if self.ww_CNN[nside] is None:
476
586
  self.init_CNN_index(nside)
477
- return self.X_CNN[nside],self.Y_CNN[nside],self.Z_CNN[nside]
587
+ return self.X_CNN[nside], self.Y_CNN[nside], self.Z_CNN[nside]
478
588
 
479
589
  # ---------------------------------------------−---------
480
- def healpix_layer_transpose(self,im,ww,indices=None,weights=None,axis=0):
481
- nside=int(np.sqrt(im.shape[axis]//12))
482
- l_kernel=self.KERNELSZ*self.KERNELSZ
483
-
484
- if im.shape[1+axis]!=ww.shape[1]:
590
+ def healpix_layer_transpose(self, im, ww, indices=None, weights=None, axis=0):
591
+ nside = int(np.sqrt(im.shape[axis] // 12))
592
+
593
+ if im.shape[1 + axis] != ww.shape[1]:
485
594
  if not self.silent:
486
- print('Weights channels should be equal to the input image channels')
595
+ print("Weights channels should be equal to the input image channels")
487
596
  return -1
488
- if axis==1:
489
- results=[]
490
-
597
+ if axis == 1:
598
+ results = []
599
+
491
600
  for k in range(im.shape[0]):
492
-
493
- tmp=self.healpix_layer(im[k],ww,indices=indices,weights=weights,axis=0)
494
- tmp=self.backend.bk_reshape(self.up_grade(tmp,2*nside),[12*4*nside**2,ww.shape[2]])
495
-
601
+
602
+ tmp = self.healpix_layer(
603
+ im[k], ww, indices=indices, weights=weights, axis=0
604
+ )
605
+ tmp = self.backend.bk_reshape(
606
+ self.up_grade(tmp, 2 * nside), [12 * 4 * nside**2, ww.shape[2]]
607
+ )
608
+
496
609
  results.append(tmp)
497
-
498
- return self.backend.bk_stack(results,axis=0)
610
+
611
+ return self.backend.bk_stack(results, axis=0)
499
612
  else:
500
- tmp=self.healpix_layer(im,ww,indices=indices,weights=weights,axis=axis)
501
-
502
- return self.up_grade(tmp,2*nside)
503
-
613
+ tmp = self.healpix_layer(
614
+ im, ww, indices=indices, weights=weights, axis=axis
615
+ )
616
+
617
+ return self.up_grade(tmp, 2 * nside)
618
+
504
619
  # ---------------------------------------------−---------
505
620
  # ---------------------------------------------−---------
506
- def healpix_layer(self,im,ww,indices=None,weights=None,axis=0):
507
- nside=int(np.sqrt(im.shape[axis]//12))
508
- l_kernel=self.KERNELSZ*self.KERNELSZ
621
+ def healpix_layer(self, im, ww, indices=None, weights=None, axis=0):
622
+ nside = int(np.sqrt(im.shape[axis] // 12))
623
+ l_kernel = self.KERNELSZ * self.KERNELSZ
509
624
 
510
- if im.shape[1+axis]!=ww.shape[1]:
625
+ if im.shape[1 + axis] != ww.shape[1]:
511
626
  if not self.silent:
512
- print('Weights channels should be equal to the input image channels')
627
+ print("Weights channels should be equal to the input image channels")
513
628
  return -1
514
629
 
515
630
  if indices is None:
516
631
  if self.ww_CNN[nside] is None:
517
- self.init_CNN_index(nside,transpose=False)
518
- mat=self.ww_CNN[nside]
632
+ self.init_CNN_index(nside, transpose=False)
633
+ mat = self.ww_CNN[nside]
519
634
  else:
520
635
  if weights is None:
521
- print('healpix_layer : If indices is not none weights should be specify')
636
+ print(
637
+ "healpix_layer : If indices is not none weights should be specify"
638
+ )
522
639
  return 0
523
-
524
- mat=self.backend.bk_SparseTensor(indices,weights,[12*nside*nside*l_kernel,12*nside*nside])
525
-
526
- if axis==1:
527
- results=[]
528
-
640
+
641
+ mat = self.backend.bk_SparseTensor(
642
+ indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
643
+ )
644
+
645
+ if axis == 1:
646
+ results = []
647
+
529
648
  for k in range(im.shape[0]):
530
-
531
- tmp=self.backend.bk_sparse_dense_matmul(mat,im[k])
532
-
533
- density=self.backend.bk_reshape(tmp,[12*nside*nside,l_kernel*im.shape[1+axis]])
534
-
535
- density=self.backend.bk_matmul(density,self.backend.bk_reshape(ww,[l_kernel*im.shape[1+axis],ww.shape[2]]))
536
-
537
- results.append(self.backend.bk_reshape(density,[12*nside**2,ww.shape[2]]))
538
-
539
- return self.backend.bk_stack(results,axis=0)
649
+
650
+ tmp = self.backend.bk_sparse_dense_matmul(mat, im[k])
651
+
652
+ density = self.backend.bk_reshape(
653
+ tmp, [12 * nside * nside, l_kernel * im.shape[1 + axis]]
654
+ )
655
+
656
+ density = self.backend.bk_matmul(
657
+ density,
658
+ self.backend.bk_reshape(
659
+ ww, [l_kernel * im.shape[1 + axis], ww.shape[2]]
660
+ ),
661
+ )
662
+
663
+ results.append(
664
+ self.backend.bk_reshape(density, [12 * nside**2, ww.shape[2]])
665
+ )
666
+
667
+ return self.backend.bk_stack(results, axis=0)
540
668
  else:
541
- tmp=self.backend.bk_sparse_dense_matmul(mat,im)
542
-
543
- density=self.backend.bk_reshape(tmp,[12*nside*nside,l_kernel*im.shape[1]])
544
-
545
- return self.backend.bk_matmul(density,self.backend.bk_reshape(ww,[l_kernel*im.shape[1],ww.shape[2]]))
669
+ tmp = self.backend.bk_sparse_dense_matmul(mat, im)
670
+
671
+ density = self.backend.bk_reshape(
672
+ tmp, [12 * nside * nside, l_kernel * im.shape[1]]
673
+ )
674
+
675
+ return self.backend.bk_matmul(
676
+ density,
677
+ self.backend.bk_reshape(ww, [l_kernel * im.shape[1], ww.shape[2]]),
678
+ )
679
+
546
680
  # ---------------------------------------------−---------
547
-
681
+
548
682
  # ---------------------------------------------−---------
549
683
  def get_rank(self):
550
- return(self.rank)
684
+ return self.rank
685
+
551
686
  # ---------------------------------------------−---------
552
687
  def get_size(self):
553
- return(self.size)
554
-
688
+ return self.size
689
+
555
690
  # ---------------------------------------------−---------
556
691
  def barrier(self):
557
692
  if self.isMPI:
558
693
  self.comm.Barrier()
559
-
694
+
560
695
  # ---------------------------------------------−---------
561
- def toring(self,image,axis=0):
562
- lout=int(np.sqrt(image.shape[axis]//12))
563
-
696
+ def toring(self, image, axis=0):
697
+ lout = int(np.sqrt(image.shape[axis] // 12))
698
+
564
699
  if self.ring2nest[lout] is None:
565
- self.ring2nest[lout]=hp.ring2nest(lout,np.arange(12*lout**2))
566
-
700
+ self.ring2nest[lout] = hp.ring2nest(lout, np.arange(12 * lout**2))
701
+
567
702
  return image.numpy()[self.ring2nest[lout]]
568
703
 
569
- #--------------------------------------------------------
570
- def ud_grade(self,im,j,axis=0):
571
- rim=im
704
+ # --------------------------------------------------------
705
+ def ud_grade(self, im, j, axis=0):
706
+ rim = im
572
707
  for k in range(j):
573
- rim=self.smooth(rim,axis=axis)
574
- rim=self.ud_grade_2(rim,axis=axis)
708
+ rim = self.smooth(rim, axis=axis)
709
+ rim = self.ud_grade_2(rim, axis=axis)
575
710
  return rim
576
-
577
- #--------------------------------------------------------
578
- def ud_grade_2(self,im,axis=0):
579
-
711
+
712
+ # --------------------------------------------------------
713
+ def ud_grade_2(self, im, axis=0):
714
+
580
715
  if self.use_2D:
581
- ishape=list(im.shape)
582
- if len(ishape)<axis+2:
716
+ ishape = list(im.shape)
717
+ if len(ishape) < axis + 2:
583
718
  if not self.silent:
584
- print('Use of 2D scat with data that has less than 2D')
719
+ print("Use of 2D scat with data that has less than 2D")
585
720
  return None
586
-
587
- npix=im.shape[axis]
588
- npiy=im.shape[axis+1]
589
- odata=1
590
- if len(ishape)>axis+2:
591
- for k in range(axis+2,len(ishape)):
592
- odata=odata*ishape[k]
593
-
594
- ndata=1
721
+
722
+ npix = im.shape[axis]
723
+ npiy = im.shape[axis + 1]
724
+ odata = 1
725
+ if len(ishape) > axis + 2:
726
+ for k in range(axis + 2, len(ishape)):
727
+ odata = odata * ishape[k]
728
+
729
+ ndata = 1
595
730
  for k in range(axis):
596
- ndata=ndata*ishape[k]
731
+ ndata = ndata * ishape[k]
597
732
 
598
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
599
- tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),0:2*(npiy//2),:],[ndata,npix//2,2,npiy//2,2,odata])
733
+ tim = self.backend.bk_reshape(
734
+ self.backend.bk_cast(im), [ndata, npix, npiy, odata]
735
+ )
736
+ tim = self.backend.bk_reshape(
737
+ tim[:, 0 : 2 * (npix // 2), 0 : 2 * (npiy // 2), :],
738
+ [ndata, npix // 2, 2, npiy // 2, 2, odata],
739
+ )
600
740
 
601
- res=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim,4),2)/4
602
-
603
- if axis==0:
604
- if len(ishape)==2:
605
- return self.backend.bk_reshape(res,[npix//2,npiy//2])
741
+ res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
742
+
743
+ if axis == 0:
744
+ if len(ishape) == 2:
745
+ return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
606
746
  else:
607
- return self.backend.bk_reshape(res,[npix//2,npiy//2]+ishape[axis+2:])
747
+ return self.backend.bk_reshape(
748
+ res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
749
+ )
608
750
  else:
609
- if len(ishape)==axis+2:
610
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2])
751
+ if len(ishape) == axis + 2:
752
+ return self.backend.bk_reshape(
753
+ res, ishape[0:axis] + [npix // 2, npiy // 2]
754
+ )
611
755
  else:
612
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2,npiy//2]+ishape[axis+2:])
613
-
614
- return self.backend.bk_reshape(res,[npix//2,npiy//2])
756
+ return self.backend.bk_reshape(
757
+ res,
758
+ ishape[0:axis] + [npix // 2, npiy // 2] + ishape[axis + 2 :],
759
+ )
760
+
761
+ return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
615
762
  elif self.use_1D:
616
- ishape=list(im.shape)
617
- if len(ishape)<axis+1:
763
+ ishape = list(im.shape)
764
+ if len(ishape) < axis + 1:
618
765
  if not self.silent:
619
- print('Use of 1D scat with data that has less than 1D')
766
+ print("Use of 1D scat with data that has less than 1D")
620
767
  return None
621
-
622
- npix=im.shape[axis]
623
- odata=1
624
- if len(ishape)>axis+1:
625
- for k in range(axis+1,len(ishape)):
626
- odata=odata*ishape[k]
627
-
628
- ndata=1
768
+
769
+ npix = im.shape[axis]
770
+ odata = 1
771
+ if len(ishape) > axis + 1:
772
+ for k in range(axis + 1, len(ishape)):
773
+ odata = odata * ishape[k]
774
+
775
+ ndata = 1
629
776
  for k in range(axis):
630
- ndata=ndata*ishape[k]
777
+ ndata = ndata * ishape[k]
631
778
 
632
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
633
- tim=self.backend.bk_reshape(tim[:,0:2*(npix//2),:],[ndata,npix//2,2,odata])
779
+ tim = self.backend.bk_reshape(
780
+ self.backend.bk_cast(im), [ndata, npix, odata]
781
+ )
782
+ tim = self.backend.bk_reshape(
783
+ tim[:, 0 : 2 * (npix // 2), :], [ndata, npix // 2, 2, odata]
784
+ )
634
785
 
635
- res=self.backend.bk_reduce_mean(tim,2)
636
-
637
- if axis==0:
638
- if len(ishape)==1:
639
- return self.backend.bk_reshape(res,[npix//2])
786
+ res = self.backend.bk_reduce_mean(tim, 2)
787
+
788
+ if axis == 0:
789
+ if len(ishape) == 1:
790
+ return self.backend.bk_reshape(res, [npix // 2])
640
791
  else:
641
- return self.backend.bk_reshape(res,[npix//2]+ishape[axis+1:])
792
+ return self.backend.bk_reshape(
793
+ res, [npix // 2] + ishape[axis + 1 :]
794
+ )
642
795
  else:
643
- if len(ishape)==axis+1:
644
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2])
796
+ if len(ishape) == axis + 1:
797
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2])
645
798
  else:
646
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix//2]+ishape[axis+1:])
647
-
648
- return self.backend.bk_reshape(res,[npix//2])
649
-
799
+ return self.backend.bk_reshape(
800
+ res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
801
+ )
802
+
803
+ return self.backend.bk_reshape(res, [npix // 2])
804
+
650
805
  else:
651
- shape=list(im.shape)
652
-
653
- lout=int(np.sqrt(shape[axis]//12))
654
- if im.__class__==np.zeros([0]).__class__:
655
- oshape=np.zeros([len(shape)+1],dtype='int')
656
- if axis>0:
657
- oshape[0:axis]=shape[0:axis]
658
- oshape[axis]=12*lout*lout//4
659
- oshape[axis+1]=4
660
- if len(shape)>axis:
661
- oshape[axis+2:]=shape[axis+1:]
806
+ shape = list(im.shape)
807
+
808
+ lout = int(np.sqrt(shape[axis] // 12))
809
+ if im.__class__ == np.zeros([0]).__class__:
810
+ oshape = np.zeros([len(shape) + 1], dtype="int")
811
+ if axis > 0:
812
+ oshape[0:axis] = shape[0:axis]
813
+ oshape[axis] = 12 * lout * lout // 4
814
+ oshape[axis + 1] = 4
815
+ if len(shape) > axis:
816
+ oshape[axis + 2 :] = shape[axis + 1 :]
662
817
  else:
663
- if axis>0:
664
- oshape=shape[0:axis]+[12*lout*lout//4,4]
818
+ if axis > 0:
819
+ oshape = shape[0:axis] + [12 * lout * lout // 4, 4]
665
820
  else:
666
- oshape=[12*lout*lout//4,4]
667
- if len(shape)>axis:
668
- oshape=oshape+shape[axis+1:]
669
-
670
- return(self.backend.bk_reduce_mean(self.backend.bk_reshape(im,oshape),axis=axis+1))
671
-
672
- #--------------------------------------------------------
673
- def up_grade(self,im,nout,axis=0,nouty=None):
674
-
821
+ oshape = [12 * lout * lout // 4, 4]
822
+ if len(shape) > axis:
823
+ oshape = oshape + shape[axis + 1 :]
824
+
825
+ return self.backend.bk_reduce_mean(
826
+ self.backend.bk_reshape(im, oshape), axis=axis + 1
827
+ )
828
+
829
+ # --------------------------------------------------------
830
+ def up_grade(self, im, nout, axis=0, nouty=None):
831
+
675
832
  if self.use_2D:
676
- ishape=list(im.shape)
677
- if len(ishape)<axis+2:
833
+ ishape = list(im.shape)
834
+ if len(ishape) < axis + 2:
678
835
  if not self.silent:
679
- print('Use of 2D scat with data that has less than 2D')
836
+ print("Use of 2D scat with data that has less than 2D")
680
837
  return None
681
-
838
+
682
839
  if nouty is None:
683
- nouty=nout
684
-
685
- if ishape[axis]==nout and ishape[axis+1]==nouty:
840
+ nouty = nout
841
+
842
+ if ishape[axis] == nout and ishape[axis + 1] == nouty:
686
843
  return im
687
-
688
- npix=im.shape[axis]
689
- npiy=im.shape[axis+1]
690
- odata=1
691
- if len(ishape)>axis+2:
692
- for k in range(axis+2,len(ishape)):
693
- odata=odata*ishape[k]
694
-
695
- ndata=1
844
+
845
+ npix = im.shape[axis]
846
+ npiy = im.shape[axis + 1]
847
+ odata = 1
848
+ if len(ishape) > axis + 2:
849
+ for k in range(axis + 2, len(ishape)):
850
+ odata = odata * ishape[k]
851
+
852
+ ndata = 1
696
853
  for k in range(axis):
697
- ndata=ndata*ishape[k]
854
+ ndata = ndata * ishape[k]
698
855
 
699
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,npiy,odata])
856
+ tim = self.backend.bk_reshape(
857
+ self.backend.bk_cast(im), [ndata, npix, npiy, odata]
858
+ )
700
859
 
701
- res=self.backend.bk_resize_image(tim,[nout,nouty])
702
-
703
- if axis==0:
704
- if len(ishape)==2:
705
- return self.backend.bk_reshape(res,[nout,nouty])
860
+ res = self.backend.bk_resize_image(tim, [nout, nouty])
861
+
862
+ if axis == 0:
863
+ if len(ishape) == 2:
864
+ return self.backend.bk_reshape(res, [nout, nouty])
706
865
  else:
707
- return self.backend.bk_reshape(res,[nout,nouty]+ishape[axis+2:])
866
+ return self.backend.bk_reshape(
867
+ res, [nout, nouty] + ishape[axis + 2 :]
868
+ )
708
869
  else:
709
- if len(ishape)==axis+2:
710
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty])
870
+ if len(ishape) == axis + 2:
871
+ return self.backend.bk_reshape(res, ishape[0:axis] + [nout, nouty])
711
872
  else:
712
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout,nouty]+ishape[axis+2:])
713
-
714
- return self.backend.bk_reshape(res,[nout,nouty])
873
+ return self.backend.bk_reshape(
874
+ res, ishape[0:axis] + [nout, nouty] + ishape[axis + 2 :]
875
+ )
876
+
877
+ return self.backend.bk_reshape(res, [nout, nouty])
715
878
 
716
879
  elif self.use_1D:
717
- ishape=list(im.shape)
718
- if len(ishape)<axis+1:
880
+ ishape = list(im.shape)
881
+ if len(ishape) < axis + 1:
719
882
  if not self.silent:
720
- print('Use of 1D scat with data that has less than 1D')
883
+ print("Use of 1D scat with data that has less than 1D")
721
884
  return None
722
-
723
- if ishape[axis]==nout:
885
+
886
+ if ishape[axis] == nout:
724
887
  return im
725
-
726
- npix=im.shape[axis]
727
- odata=1
728
- if len(ishape)>axis+1:
729
- for k in range(axis+1,len(ishape)):
730
- odata=odata*ishape[k]
731
-
732
- ndata=1
888
+
889
+ npix = im.shape[axis]
890
+ odata = 1
891
+ if len(ishape) > axis + 1:
892
+ for k in range(axis + 1, len(ishape)):
893
+ odata = odata * ishape[k]
894
+
895
+ ndata = 1
733
896
  for k in range(axis):
734
- ndata=ndata*ishape[k]
735
-
736
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
737
-
738
- while tim.shape[1]!=nout:
739
- res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
740
- res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
741
- tim = self.backend.bk_reshape(self.backend.bk_concat([res1,res2],-2),[ndata,tim.shape[1]*2,odata])
742
-
743
- if axis==0:
744
- if len(ishape)==1:
745
- return self.backend.bk_reshape(tim,[nout])
897
+ ndata = ndata * ishape[k]
898
+
899
+ tim = self.backend.bk_reshape(
900
+ self.backend.bk_cast(im), [ndata, npix, odata]
901
+ )
902
+
903
+ while tim.shape[1] != nout:
904
+ res2 = self.backend.bk_expand_dims(
905
+ self.backend.bk_concat(
906
+ [(tim[:, 1:, :] + 3 * tim[:, :-1, :]) / 4, tim[:, -1:, :]], 1
907
+ ),
908
+ -2,
909
+ )
910
+ res1 = self.backend.bk_expand_dims(
911
+ self.backend.bk_concat(
912
+ [tim[:, 0:1, :], (tim[:, 1:, :] * 3 + tim[:, :-1, :]) / 4], 1
913
+ ),
914
+ -2,
915
+ )
916
+ tim = self.backend.bk_reshape(
917
+ self.backend.bk_concat([res1, res2], -2),
918
+ [ndata, tim.shape[1] * 2, odata],
919
+ )
920
+
921
+ if axis == 0:
922
+ if len(ishape) == 1:
923
+ return self.backend.bk_reshape(tim, [nout])
746
924
  else:
747
- return self.backend.bk_reshape(tim,[nout]+ishape[axis+1:])
925
+ return self.backend.bk_reshape(tim, [nout] + ishape[axis + 1 :])
748
926
  else:
749
- if len(ishape)==axis+1:
750
- return self.backend.bk_reshape(tim,ishape[0:axis]+[nout])
927
+ if len(ishape) == axis + 1:
928
+ return self.backend.bk_reshape(tim, ishape[0:axis] + [nout])
751
929
  else:
752
- return self.backend.bk_reshape(tim,ishape[0:axis]+[nout]+ishape[axis+1:])
753
-
754
- return self.backend.bk_reshape(tim,[nout])
755
-
930
+ return self.backend.bk_reshape(
931
+ tim, ishape[0:axis] + [nout] + ishape[axis + 1 :]
932
+ )
933
+
934
+ return self.backend.bk_reshape(tim, [nout])
935
+
756
936
  else:
757
937
 
758
- lout=int(np.sqrt(im.shape[axis]//12))
759
-
938
+ lout = int(np.sqrt(im.shape[axis] // 12))
939
+
760
940
  if self.pix_interp_val[lout][nout] is None:
761
941
  if not self.silent:
762
- print('compute lout nout',lout,nout)
763
- th,ph=hp.pix2ang(nout,np.arange(12*nout**2,dtype='int'),nest=True)
764
- p, w = hp.get_interp_weights(lout,th,ph,nest=True)
942
+ print("compute lout nout", lout, nout)
943
+ th, ph = hp.pix2ang(
944
+ nout, np.arange(12 * nout**2, dtype="int"), nest=True
945
+ )
946
+ p, w = hp.get_interp_weights(lout, th, ph, nest=True)
765
947
  del th
766
948
  del ph
767
-
768
- indice=np.zeros([12*nout*nout*4,2],dtype='int')
769
- p=p.T
770
- w=w.T
771
- t=np.argsort(p,1).flatten() # to make oder indices for sparsematrix computation
772
- t=(t+np.repeat(np.arange(12*nout*nout)*4,4))
773
- p=p.flatten()[t]
774
- w=w.flatten()[t]
775
- indice[:,0]=np.repeat(np.arange(12*nout**2),4)
776
- indice[:,1]=p
777
-
778
- self.pix_interp_val[lout][nout]=1
779
- self.weight_interp_val[lout][nout] = self.backend.bk_SparseTensor(self.backend.constant(indice), \
780
- self.backend.constant(self.backend.bk_cast(w.flatten())), \
781
- dense_shape=[12*nout**2,12*lout**2])
782
-
783
- if lout==nout:
784
- imout=im
949
+
950
+ indice = np.zeros([12 * nout * nout * 4, 2], dtype="int")
951
+ p = p.T
952
+ w = w.T
953
+ t = np.argsort(
954
+ p, 1
955
+ ).flatten() # to make oder indices for sparsematrix computation
956
+ t = t + np.repeat(np.arange(12 * nout * nout) * 4, 4)
957
+ p = p.flatten()[t]
958
+ w = w.flatten()[t]
959
+ indice[:, 0] = np.repeat(np.arange(12 * nout**2), 4)
960
+ indice[:, 1] = p
961
+
962
+ self.pix_interp_val[lout][nout] = 1
963
+ self.weight_interp_val[lout][nout] = self.backend.bk_SparseTensor(
964
+ self.backend.constant(indice),
965
+ self.backend.constant(self.backend.bk_cast(w.flatten())),
966
+ dense_shape=[12 * nout**2, 12 * lout**2],
967
+ )
968
+
969
+ if lout == nout:
970
+ imout = im
785
971
  else:
786
972
 
787
- ishape=list(im.shape)
788
- odata=1
789
- for k in range(axis+1,len(ishape)):
790
- odata=odata*ishape[k]
791
-
792
- ndata=1
973
+ ishape = list(im.shape)
974
+ odata = 1
975
+ for k in range(axis + 1, len(ishape)):
976
+ odata = odata * ishape[k]
977
+
978
+ ndata = 1
793
979
  for k in range(axis):
794
- ndata=ndata*ishape[k]
795
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,12*lout**2,odata])
796
- if tim.dtype==self.all_cbk_type:
797
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
798
- ,self.backend.bk_real(tim[0])),[1,12*nout**2,odata])
799
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
800
- ,self.backend.bk_imag(tim[0])),[1,12*nout**2,odata])
801
- imout=self.backend.bk_complex(rr,ii)
980
+ ndata = ndata * ishape[k]
981
+ tim = self.backend.bk_reshape(
982
+ self.backend.bk_cast(im), [ndata, 12 * lout**2, odata]
983
+ )
984
+ if tim.dtype == self.all_cbk_type:
985
+ rr = self.backend.bk_reshape(
986
+ self.backend.bk_sparse_dense_matmul(
987
+ self.weight_interp_val[lout][nout],
988
+ self.backend.bk_real(tim[0]),
989
+ ),
990
+ [1, 12 * nout**2, odata],
991
+ )
992
+ ii = self.backend.bk_reshape(
993
+ self.backend.bk_sparse_dense_matmul(
994
+ self.weight_interp_val[lout][nout],
995
+ self.backend.bk_imag(tim[0]),
996
+ ),
997
+ [1, 12 * nout**2, odata],
998
+ )
999
+ imout = self.backend.bk_complex(rr, ii)
802
1000
  else:
803
- imout=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
804
- ,tim[0]),[1,12*nout**2,odata])
805
-
806
- for k in range(1,ndata):
807
- if tim.dtype==self.all_cbk_type:
808
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
809
- ,self.backend.bk_real(tim[k])),[1,12*nout**2,odata])
810
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
811
- ,self.backend.bk_imag(tim[k])),[1,12*nout**2,odata])
812
- imout=self.backend.bk_concat([imout,self.backend.bk_complex(rr,ii)],0)
1001
+ imout = self.backend.bk_reshape(
1002
+ self.backend.bk_sparse_dense_matmul(
1003
+ self.weight_interp_val[lout][nout], tim[0]
1004
+ ),
1005
+ [1, 12 * nout**2, odata],
1006
+ )
1007
+
1008
+ for k in range(1, ndata):
1009
+ if tim.dtype == self.all_cbk_type:
1010
+ rr = self.backend.bk_reshape(
1011
+ self.backend.bk_sparse_dense_matmul(
1012
+ self.weight_interp_val[lout][nout],
1013
+ self.backend.bk_real(tim[k]),
1014
+ ),
1015
+ [1, 12 * nout**2, odata],
1016
+ )
1017
+ ii = self.backend.bk_reshape(
1018
+ self.backend.bk_sparse_dense_matmul(
1019
+ self.weight_interp_val[lout][nout],
1020
+ self.backend.bk_imag(tim[k]),
1021
+ ),
1022
+ [1, 12 * nout**2, odata],
1023
+ )
1024
+ imout = self.backend.bk_concat(
1025
+ [imout, self.backend.bk_complex(rr, ii)], 0
1026
+ )
813
1027
  else:
814
- imout=self.backend.bk_concat([imout,self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(self.weight_interp_val[lout][nout]
815
- ,tim[k]),[1,12*nout**2,odata])],0)
816
-
817
- if axis==0:
818
- if len(ishape)==1:
819
- return self.backend.bk_reshape(imout,[12*nout**2])
1028
+ imout = self.backend.bk_concat(
1029
+ [
1030
+ imout,
1031
+ self.backend.bk_reshape(
1032
+ self.backend.bk_sparse_dense_matmul(
1033
+ self.weight_interp_val[lout][nout], tim[k]
1034
+ ),
1035
+ [1, 12 * nout**2, odata],
1036
+ ),
1037
+ ],
1038
+ 0,
1039
+ )
1040
+
1041
+ if axis == 0:
1042
+ if len(ishape) == 1:
1043
+ return self.backend.bk_reshape(imout, [12 * nout**2])
820
1044
  else:
821
- return self.backend.bk_reshape(imout,[12*nout**2]+ishape[axis+1:])
1045
+ return self.backend.bk_reshape(
1046
+ imout, [12 * nout**2] + ishape[axis + 1 :]
1047
+ )
822
1048
  else:
823
- if len(ishape)==axis+1:
824
- return self.backend.bk_reshape(imout,ishape[0:axis]+[12*nout**2])
1049
+ if len(ishape) == axis + 1:
1050
+ return self.backend.bk_reshape(
1051
+ imout, ishape[0:axis] + [12 * nout**2]
1052
+ )
825
1053
  else:
826
- return self.backend.bk_reshape(imout,ishape[0:axis]+[12*nout**2]+ishape[axis+1:])
827
- return(imout)
828
-
829
- #--------------------------------------------------------
830
- def fill_1d(self,i_arr,nullval=0):
831
- arr=i_arr.copy()
1054
+ return self.backend.bk_reshape(
1055
+ imout, ishape[0:axis] + [12 * nout**2] + ishape[axis + 1 :]
1056
+ )
1057
+ return imout
1058
+
1059
+ # --------------------------------------------------------
1060
+ def fill_1d(self, i_arr, nullval=0):
1061
+ arr = i_arr.copy()
832
1062
  # Indices des éléments non nuls
833
- non_zero_indices = np.where(arr!=nullval)[0]
834
-
1063
+ non_zero_indices = np.where(arr != nullval)[0]
1064
+
835
1065
  # Indices de tous les éléments
836
1066
  all_indices = np.arange(len(arr))
837
-
1067
+
838
1068
  # Interpoler linéairement en utilisant np.interp
839
1069
  # np.interp(x, xp, fp) : x sont les indices pour lesquels on veut obtenir des valeurs
840
1070
  # xp sont les indices des données existantes, fp sont les valeurs des données existantes
841
- interpolated_values = np.interp(all_indices, non_zero_indices, arr[non_zero_indices])
842
-
1071
+ interpolated_values = np.interp(
1072
+ all_indices, non_zero_indices, arr[non_zero_indices]
1073
+ )
1074
+
843
1075
  # Mise à jour du tableau original
844
- arr[arr==nullval] = interpolated_values[arr==nullval]
845
-
1076
+ arr[arr == nullval] = interpolated_values[arr == nullval]
1077
+
846
1078
  return arr
847
1079
 
848
- def fill_2d(self,i_arr,nullval=0):
849
- arr=i_arr.copy()
1080
+ def fill_2d(self, i_arr, nullval=0):
1081
+ arr = i_arr.copy()
850
1082
  # Créer une grille de coordonnées correspondant aux indices du tableau
851
1083
  x, y = np.indices(arr.shape)
852
-
1084
+
853
1085
  # Extraire les coordonnées des points non nuls ainsi que leurs valeurs
854
1086
  non_zero_points = np.array((x[arr != nullval], y[arr != nullval])).T
855
1087
  non_zero_values = arr[arr != nullval]
856
-
1088
+
857
1089
  # Extraire les coordonnées des points nuls
858
1090
  zero_points = np.array((x[arr == nullval], y[arr == nullval])).T
859
1091
 
860
1092
  # Interpolation linéaire
861
- interpolated_values = griddata(non_zero_points, non_zero_values, zero_points, method='linear')
1093
+ interpolated_values = griddata(
1094
+ non_zero_points, non_zero_values, zero_points, method="linear"
1095
+ )
862
1096
 
863
1097
  # Remplacer les valeurs nulles par les valeurs interpolées
864
1098
  arr[arr == nullval] = interpolated_values
865
1099
 
866
1100
  return arr
867
-
868
- def fill_healpy(self,i_map,nmax=10,nullval=hp.UNSEEN):
869
- map=1*i_map
1101
+
1102
+ def fill_healpy(self, i_map, nmax=10, nullval=hp.UNSEEN):
1103
+ map = 1 * i_map
870
1104
  # Trouver les pixels nuls
871
1105
  nside = hp.npix2nside(len(map))
872
1106
  null_indices = np.where(map == nullval)[0]
873
-
874
- itt=0
875
- while null_indices.shape[0]>0 and itt<nmax:
1107
+
1108
+ itt = 0
1109
+ while null_indices.shape[0] > 0 and itt < nmax:
876
1110
  # Trouver les coordonnées theta, phi pour les pixels nuls
877
1111
  theta, phi = hp.pix2ang(nside, null_indices)
878
-
1112
+
879
1113
  # Interpoler les valeurs en utilisant les pixels voisins
880
1114
  # La fonction get_interp_val peut être utilisée pour obtenir les valeurs interpolées
881
1115
  # pour des positions données en theta et phi.
882
1116
  i_idx = hp.get_all_neighbours(nside, theta, phi)
883
-
884
- i_w=(map[i_idx]!=nullval)*(i_idx!=-1)
885
- vv=np.sum(i_w,0)
886
- interpolated_values=np.sum(i_w*map[i_idx],0)
1117
+
1118
+ i_w = (map[i_idx] != nullval) * (i_idx != -1)
1119
+ vv = np.sum(i_w, 0)
1120
+ interpolated_values = np.sum(i_w * map[i_idx], 0)
887
1121
 
888
1122
  # Remplacer les valeurs nulles par les valeurs interpolées
889
- map[null_indices[vv>0]] = interpolated_values[vv>0]/vv[vv>0]
1123
+ map[null_indices[vv > 0]] = interpolated_values[vv > 0] / vv[vv > 0]
890
1124
 
891
1125
  null_indices = np.where(map == nullval)[0]
892
- itt+=1
893
-
1126
+ itt += 1
1127
+
894
1128
  return map
895
-
896
- #--------------------------------------------------------
897
- def ud_grade_1d(self,im,nout,axis=0):
898
- npix=im.shape[axis]
899
-
900
- ishape=list(im.shape)
901
- odata=1
902
- for k in range(axis+1,len(ishape)):
903
- odata=odata*ishape[k]
904
-
905
- ndata=1
1129
+
1130
+ # --------------------------------------------------------
1131
+ def ud_grade_1d(self, im, nout, axis=0):
1132
+ npix = im.shape[axis]
1133
+
1134
+ ishape = list(im.shape)
1135
+ odata = 1
1136
+ for k in range(axis + 1, len(ishape)):
1137
+ odata = odata * ishape[k]
1138
+
1139
+ ndata = 1
906
1140
  for k in range(axis):
907
- ndata=ndata*ishape[k]
1141
+ ndata = ndata * ishape[k]
908
1142
 
909
- nscale=npix//nout
910
- if npix%nscale==0:
911
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix//nscale,nscale,odata])
1143
+ nscale = npix // nout
1144
+ if npix % nscale == 0:
1145
+ tim = self.backend.bk_reshape(
1146
+ self.backend.bk_cast(im), [ndata, npix // nscale, nscale, odata]
1147
+ )
912
1148
  else:
913
- im=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
914
- tim=self.backend.bk_reshape(self.backend.bk_cast(im[:,0:nscale*(npix//nscale),:]),[ndata,npix//nscale,nscale,odata])
915
- res = self.backend.bk_reduce_mean(tim,2)
916
-
917
- if axis==0:
918
- if len(ishape)==1:
919
- return self.backend.bk_reshape(res,[nout])
1149
+ im = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1150
+ tim = self.backend.bk_reshape(
1151
+ self.backend.bk_cast(im[:, 0 : nscale * (npix // nscale), :]),
1152
+ [ndata, npix // nscale, nscale, odata],
1153
+ )
1154
+ res = self.backend.bk_reduce_mean(tim, 2)
1155
+
1156
+ if axis == 0:
1157
+ if len(ishape) == 1:
1158
+ return self.backend.bk_reshape(res, [nout])
920
1159
  else:
921
- return self.backend.bk_reshape(res,[nout]+ishape[axis+1:])
1160
+ return self.backend.bk_reshape(res, [nout] + ishape[axis + 1 :])
922
1161
  else:
923
- if len(ishape)==axis+1:
924
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout])
1162
+ if len(ishape) == axis + 1:
1163
+ return self.backend.bk_reshape(res, ishape[0:axis] + [nout])
925
1164
  else:
926
- return self.backend.bk_reshape(res,ishape[0:axis]+[nout]+ishape[axis+1:])
927
- return self.backend.bk_reshape(res,[nout])
928
-
929
- #--------------------------------------------------------
930
- def up_grade_2_1d(self,im,axis=0):
931
-
932
- npix=im.shape[axis]
933
-
934
- ishape=list(im.shape)
935
- odata=1
936
- for k in range(axis+1,len(ishape)):
937
- odata=odata*ishape[k]
938
-
939
- ndata=1
1165
+ return self.backend.bk_reshape(
1166
+ res, ishape[0:axis] + [nout] + ishape[axis + 1 :]
1167
+ )
1168
+ return self.backend.bk_reshape(res, [nout])
1169
+
1170
+ # --------------------------------------------------------
1171
+ def up_grade_2_1d(self, im, axis=0):
1172
+
1173
+ npix = im.shape[axis]
1174
+
1175
+ ishape = list(im.shape)
1176
+ odata = 1
1177
+ for k in range(axis + 1, len(ishape)):
1178
+ odata = odata * ishape[k]
1179
+
1180
+ ndata = 1
940
1181
  for k in range(axis):
941
- ndata=ndata*ishape[k]
942
-
943
- tim=self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
944
-
945
- res2=self.backend.bk_expand_dims(self.backend.bk_concat([(tim[:,1:,:]+3*tim[:,:-1,:])/4,tim[:,-1:,:]],1),-2)
946
- res1=self.backend.bk_expand_dims(self.backend.bk_concat([tim[:,0:1,:],(tim[:,1:,:]*3+tim[:,:-1,:])/4],1),-2)
947
- res = self.backend.bk_concat([res1,res2],-2)
948
-
949
- if axis==0:
950
- if len(ishape)==1:
951
- return self.backend.bk_reshape(res,[npix*2])
1182
+ ndata = ndata * ishape[k]
1183
+
1184
+ tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1185
+
1186
+ res2 = self.backend.bk_expand_dims(
1187
+ self.backend.bk_concat(
1188
+ [(tim[:, 1:, :] + 3 * tim[:, :-1, :]) / 4, tim[:, -1:, :]], 1
1189
+ ),
1190
+ -2,
1191
+ )
1192
+ res1 = self.backend.bk_expand_dims(
1193
+ self.backend.bk_concat(
1194
+ [tim[:, 0:1, :], (tim[:, 1:, :] * 3 + tim[:, :-1, :]) / 4], 1
1195
+ ),
1196
+ -2,
1197
+ )
1198
+ res = self.backend.bk_concat([res1, res2], -2)
1199
+
1200
+ if axis == 0:
1201
+ if len(ishape) == 1:
1202
+ return self.backend.bk_reshape(res, [npix * 2])
952
1203
  else:
953
- return self.backend.bk_reshape(res,[npix*2]+ishape[axis+1:])
1204
+ return self.backend.bk_reshape(res, [npix * 2] + ishape[axis + 1 :])
954
1205
  else:
955
- if len(ishape)==axis+1:
956
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix*2])
1206
+ if len(ishape) == axis + 1:
1207
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix * 2])
957
1208
  else:
958
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix*2]+ishape[axis+1:])
959
- return self.backend.bk_reshape(res,[npix*2])
960
-
961
-
962
- #--------------------------------------------------------
963
- def convol_1d(self,im,axis=0):
964
-
965
- xx=np.arange(5)-2
966
- w=np.exp(-0.17328679514*(xx)**2)
967
- c=np.cos((xx)*np.pi/2)
968
- s=np.sin((xx)*np.pi/2)
969
-
970
- wr=np.array(w*c).reshape(xx.shape[0],1,1)
971
- wi=np.array(w*s).reshape(xx.shape[0],1,1)
972
-
973
- npix=im.shape[axis]
974
-
975
- ishape=list(im.shape)
976
- odata=1
977
- for k in range(axis+1,len(ishape)):
978
- odata=odata*ishape[k]
979
-
980
- ndata=1
1209
+ return self.backend.bk_reshape(
1210
+ res, ishape[0:axis] + [npix * 2] + ishape[axis + 1 :]
1211
+ )
1212
+ return self.backend.bk_reshape(res, [npix * 2])
1213
+
1214
+ # --------------------------------------------------------
1215
+ def convol_1d(self, im, axis=0):
1216
+
1217
+ xx = np.arange(5) - 2
1218
+ w = np.exp(-0.17328679514 * (xx) ** 2)
1219
+ c = np.cos((xx) * np.pi / 2)
1220
+ s = np.sin((xx) * np.pi / 2)
1221
+
1222
+ wr = np.array(w * c).reshape(xx.shape[0], 1, 1)
1223
+ wi = np.array(w * s).reshape(xx.shape[0], 1, 1)
1224
+
1225
+ npix = im.shape[axis]
1226
+
1227
+ ishape = list(im.shape)
1228
+ odata = 1
1229
+ for k in range(axis + 1, len(ishape)):
1230
+ odata = odata * ishape[k]
1231
+
1232
+ ndata = 1
981
1233
  for k in range(axis):
982
- ndata=ndata*ishape[k]
983
-
984
- if odata>1:
985
- wr=np.repeat(wr,odata,2)
986
- wi=np.repeat(wi,odata,2)
987
-
988
- wr=self.backend.bk_cast(self.backend.constant(wr))
989
- wi=self.backend.bk_cast(self.backend.constant(wi))
990
-
991
- tim = self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
992
-
993
- if tim.dtype==self.all_cbk_type:
994
- rr1 = self.backend.bk_conv1d(self.backend.bk_real(tim),wr)
995
- ii1 = self.backend.bk_conv1d(self.backend.bk_real(tim),wi)
996
- rr2 = self.backend.bk_conv1d(self.backend.bk_imag(tim),wr)
997
- ii2 = self.backend.bk_conv1d(self.backend.bk_imag(tim),wi)
998
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
1234
+ ndata = ndata * ishape[k]
1235
+
1236
+ if odata > 1:
1237
+ wr = np.repeat(wr, odata, 2)
1238
+ wi = np.repeat(wi, odata, 2)
1239
+
1240
+ wr = self.backend.bk_cast(self.backend.constant(wr))
1241
+ wi = self.backend.bk_cast(self.backend.constant(wi))
1242
+
1243
+ tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1244
+
1245
+ if tim.dtype == self.all_cbk_type:
1246
+ rr1 = self.backend.bk_conv1d(self.backend.bk_real(tim), wr)
1247
+ ii1 = self.backend.bk_conv1d(self.backend.bk_real(tim), wi)
1248
+ rr2 = self.backend.bk_conv1d(self.backend.bk_imag(tim), wr)
1249
+ ii2 = self.backend.bk_conv1d(self.backend.bk_imag(tim), wi)
1250
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
999
1251
  else:
1000
- rr = self.backend.bk_conv1d(tim,wr)
1001
- ii = self.backend.bk_conv1d(tim,wi)
1002
-
1003
- res=self.backend.bk_complex(rr,ii)
1004
-
1005
- if axis==0:
1006
- if len(ishape)==1:
1007
- return self.backend.bk_reshape(res,[npix])
1252
+ rr = self.backend.bk_conv1d(tim, wr)
1253
+ ii = self.backend.bk_conv1d(tim, wi)
1254
+
1255
+ res = self.backend.bk_complex(rr, ii)
1256
+
1257
+ if axis == 0:
1258
+ if len(ishape) == 1:
1259
+ return self.backend.bk_reshape(res, [npix])
1008
1260
  else:
1009
- return self.backend.bk_reshape(res,[npix]+ishape[axis+1:])
1261
+ return self.backend.bk_reshape(res, [npix] + ishape[axis + 1 :])
1010
1262
  else:
1011
- if len(ishape)==axis+1:
1012
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix])
1263
+ if len(ishape) == axis + 1:
1264
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix])
1013
1265
  else:
1014
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix]+ishape[axis+1:])
1015
- return self.backend.bk_reshape(res,[npix])
1016
-
1017
-
1018
- #--------------------------------------------------------
1019
- def smooth_1d(self,im,axis=0):
1020
-
1021
- xx=np.arange(5)-2
1022
- w=np.exp(-0.17328679514*(xx)**2)
1023
- w=w/w.sum()
1024
- w=np.array(w).reshape(xx.shape[0],1,1)
1025
-
1026
- npix=im.shape[axis]
1027
-
1028
- ishape=list(im.shape)
1029
- odata=1
1030
- for k in range(axis+1,len(ishape)):
1031
- odata=odata*ishape[k]
1032
-
1033
- ndata=1
1266
+ return self.backend.bk_reshape(
1267
+ res, ishape[0:axis] + [npix] + ishape[axis + 1 :]
1268
+ )
1269
+ return self.backend.bk_reshape(res, [npix])
1270
+
1271
+ # --------------------------------------------------------
1272
+ def smooth_1d(self, im, axis=0):
1273
+
1274
+ xx = np.arange(5) - 2
1275
+ w = np.exp(-0.17328679514 * (xx) ** 2)
1276
+ w = w / w.sum()
1277
+ w = np.array(w).reshape(xx.shape[0], 1, 1)
1278
+
1279
+ npix = im.shape[axis]
1280
+
1281
+ ishape = list(im.shape)
1282
+ odata = 1
1283
+ for k in range(axis + 1, len(ishape)):
1284
+ odata = odata * ishape[k]
1285
+
1286
+ ndata = 1
1034
1287
  for k in range(axis):
1035
- ndata=ndata*ishape[k]
1036
-
1037
- if odata>1:
1038
- w=np.repeat(w,odata,2)
1039
-
1040
- w=self.backend.bk_cast(self.backend.constant(w))
1041
-
1042
- tim = self.backend.bk_reshape(self.backend.bk_cast(im),[ndata,npix,odata])
1043
-
1044
- if tim.dtype==self.all_cbk_type:
1045
- rr = self.backend.bk_conv1d(self.backend.bk_real(tim),w)
1046
- ii = self.backend.bk_conv1d(self.backend.bk_real(tim),w)
1047
- res=self.backend.bk_complex(rr,ii)
1288
+ ndata = ndata * ishape[k]
1289
+
1290
+ if odata > 1:
1291
+ w = np.repeat(w, odata, 2)
1292
+
1293
+ w = self.backend.bk_cast(self.backend.constant(w))
1294
+
1295
+ tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata])
1296
+
1297
+ if tim.dtype == self.all_cbk_type:
1298
+ rr = self.backend.bk_conv1d(self.backend.bk_real(tim), w)
1299
+ ii = self.backend.bk_conv1d(self.backend.bk_real(tim), w)
1300
+ res = self.backend.bk_complex(rr, ii)
1048
1301
  else:
1049
- res=self.backend.bk_conv1d(tim,w)
1050
-
1051
- if axis==0:
1052
- if len(ishape)==1:
1053
- return self.backend.bk_reshape(res,[npix])
1302
+ res = self.backend.bk_conv1d(tim, w)
1303
+
1304
+ if axis == 0:
1305
+ if len(ishape) == 1:
1306
+ return self.backend.bk_reshape(res, [npix])
1054
1307
  else:
1055
- return self.backend.bk_reshape(res,[npix]+ishape[axis+1:])
1308
+ return self.backend.bk_reshape(res, [npix] + ishape[axis + 1 :])
1056
1309
  else:
1057
- if len(ishape)==axis+1:
1058
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix])
1310
+ if len(ishape) == axis + 1:
1311
+ return self.backend.bk_reshape(res, ishape[0:axis] + [npix])
1059
1312
  else:
1060
- return self.backend.bk_reshape(res,ishape[0:axis]+[npix]+ishape[axis+1:])
1061
- return self.backend.bk_reshape(res,[npix])
1062
-
1063
- #--------------------------------------------------------
1064
- def up_grade_1d(self,im,nout,axis=0):
1065
-
1066
- lout=int(im.shape[axis])
1067
- nscale=int(np.log(nout//lout)/np.log(2))
1068
- res=self.backend.bk_cast(im)
1313
+ return self.backend.bk_reshape(
1314
+ res, ishape[0:axis] + [npix] + ishape[axis + 1 :]
1315
+ )
1316
+ return self.backend.bk_reshape(res, [npix])
1317
+
1318
+ # --------------------------------------------------------
1319
+ def up_grade_1d(self, im, nout, axis=0):
1320
+
1321
+ lout = int(im.shape[axis])
1322
+ nscale = int(np.log(nout // lout) / np.log(2))
1323
+ res = self.backend.bk_cast(im)
1069
1324
  for k in range(nscale):
1070
- res=self.up_grade_2_1d(res,axis=axis)
1071
- return(res)
1072
-
1325
+ res = self.up_grade_2_1d(res, axis=axis)
1326
+ return res
1327
+
1073
1328
  # ---------------------------------------------−---------
1074
- def init_index(self,nside,kernel=-1):
1329
+ def init_index(self, nside, kernel=-1):
1075
1330
 
1076
- if kernel==-1:
1077
- l_kernel=self.KERNELSZ
1331
+ if kernel == -1:
1332
+ l_kernel = self.KERNELSZ
1078
1333
  else:
1079
- l_kernel=kernel
1080
-
1081
-
1334
+ l_kernel = kernel
1335
+
1082
1336
  try:
1083
1337
  if self.use_2D:
1084
- tmp=np.load('%s/W%d_%s_%d_IDX.npy'%(self.TEMPLATE_PATH,l_kernel**2,TMPFILE_VERSION,nside))
1338
+ tmp = np.load(
1339
+ "%s/W%d_%s_%d_IDX.npy"
1340
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1341
+ )
1085
1342
  else:
1086
- tmp=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,l_kernel**2,self.NORIENT,nside))
1343
+ tmp = np.load(
1344
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1345
+ % (
1346
+ self.TEMPLATE_PATH,
1347
+ TMPFILE_VERSION,
1348
+ l_kernel**2,
1349
+ self.NORIENT,
1350
+ nside,
1351
+ )
1352
+ )
1087
1353
  except:
1088
- if self.use_2D==False:
1089
-
1090
- if l_kernel==5:
1091
- pw=0.5
1092
- pw2=0.5
1093
- threshold=2E-4
1094
-
1095
- elif l_kernel==3:
1096
- pw=1.0/np.sqrt(2)
1097
- pw2=1.0
1098
- threshold=1E-3
1099
-
1100
- elif l_kernel==7:
1101
- pw=0.5
1102
- pw2=0.25
1103
- threshold=4E-5
1104
-
1105
- th,ph=hp.pix2ang(nside,np.arange(12*nside**2),nest=True)
1106
- x,y,z=hp.pix2vec(nside,np.arange(12*nside**2),nest=True)
1107
-
1108
- t,p=hp.pix2ang(nside,np.arange(12*nside*nside),nest=True)
1109
- phi=[p[k]/np.pi*180 for k in range(12*nside*nside)]
1110
- thi=[t[k]/np.pi*180 for k in range(12*nside*nside)]
1111
-
1112
-
1113
- indice2=np.zeros([12*nside*nside*64,2],dtype='int')
1114
- indice=np.zeros([12*nside*nside*64*self.NORIENT,2],dtype='int')
1115
- wav=np.zeros([12*nside*nside*64*self.NORIENT],dtype='complex')
1116
- wwav=np.zeros([12*nside*nside*64*self.NORIENT],dtype='float')
1117
-
1118
- iv=0
1119
- iv2=0
1120
- for iii in range(12*nside*nside):
1121
-
1122
- if iii%(nside*nside)==nside*nside-1:
1354
+ if not self.use_2D:
1355
+
1356
+ if l_kernel == 5:
1357
+ pw = 0.5
1358
+ pw2 = 0.5
1359
+ threshold = 2e-4
1360
+
1361
+ elif l_kernel == 3:
1362
+ pw = 1.0 / np.sqrt(2)
1363
+ pw2 = 1.0
1364
+ threshold = 1e-3
1365
+
1366
+ elif l_kernel == 7:
1367
+ pw = 0.5
1368
+ pw2 = 0.25
1369
+ threshold = 4e-5
1370
+
1371
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1372
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1373
+
1374
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1375
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1376
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1377
+
1378
+ indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1379
+ indice = np.zeros(
1380
+ [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1381
+ )
1382
+ wav = np.zeros(
1383
+ [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1384
+ )
1385
+ wwav = np.zeros([12 * nside * nside * 64 * self.NORIENT], dtype="float")
1386
+
1387
+ iv = 0
1388
+ iv2 = 0
1389
+ for iii in range(12 * nside * nside):
1390
+
1391
+ if iii % (nside * nside) == nside * nside - 1:
1123
1392
  if not self.silent:
1124
- print('Pre-compute nside=%6d %.2f%%'%(nside,100*iii/(12*nside*nside)))
1125
-
1126
- hidx=hp.query_disc(nside, [x[iii],y[iii],z[iii]], 2*np.pi/nside,nest=True)
1127
-
1128
- R=hp.Rotator(rot=[phi[iii],-thi[iii]],eulertype='ZYZ')
1129
-
1130
- t2,p2=R(th[hidx],ph[hidx])
1131
-
1132
- vec2=hp.ang2vec(t2,p2)
1133
-
1134
- x2=vec2[:,0]
1135
- y2=vec2[:,1]
1136
- z2=vec2[:,2]
1137
-
1138
- ww=np.exp(-pw2*((nside)**2)*((x2)**2+(y2)**2+(z2-1.0)**2))
1139
- idx=np.where((ww**2)>threshold)[0]
1140
- nval2=len(idx)
1141
- indice2[iv2:iv2+nval2,0]=iii
1142
- indice2[iv2:iv2+nval2,1]=hidx[idx]
1143
- wwav[iv2:iv2+nval2]=ww[idx]/np.sum(ww[idx])
1144
- iv2+=nval2
1145
-
1146
- for l in range(self.NORIENT):
1147
-
1148
- angle=l/4.*np.pi-phi[iii]/180.*np.pi*(z[hidx]>0)-(180.0-phi[iii])/180.*np.pi*(z[hidx]<0)
1149
-
1150
- #posi=2*(0.5-(z[hidx]<0))
1151
-
1152
- axes=y2*np.cos(angle)-x2*np.sin(angle)
1153
- wresr=ww*np.cos(pw*axes*(nside)*np.pi)
1154
- wresi=ww*np.sin(pw*axes*(nside)*np.pi)
1155
-
1156
- vnorm=(wresr*wresr+wresi*wresi)
1157
- idx=np.where(vnorm>threshold)[0]
1158
-
1159
- nval=len(idx)
1160
- indice[iv:iv+nval,0]=iii*4+l
1161
- indice[iv:iv+nval,1]=hidx[idx]
1162
- #print([hidx[k] for k in idx])
1163
- #print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1164
- normr=np.mean(wresr[idx])
1165
- normi=np.mean(wresi[idx])
1166
-
1167
- val=wresr[idx]-normr+1J*(wresi[idx]-normi)
1168
- val=val/abs(val).sum()
1169
-
1170
- wav[iv:iv+nval]=val
1171
- iv+=nval
1172
-
1173
- indice=indice[:iv,:]
1174
- wav=wav[:iv]
1175
- indice2=indice2[:iv2,:]
1176
- wwav=wwav[:iv2]
1393
+ print(
1394
+ "Pre-compute nside=%6d %.2f%%"
1395
+ % (nside, 100 * iii / (12 * nside * nside))
1396
+ )
1397
+
1398
+ hidx = hp.query_disc(
1399
+ nside, [x[iii], y[iii], z[iii]], 2 * np.pi / nside, nest=True
1400
+ )
1401
+
1402
+ R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1403
+
1404
+ t2, p2 = R(th[hidx], ph[hidx])
1405
+
1406
+ vec2 = hp.ang2vec(t2, p2)
1407
+
1408
+ x2 = vec2[:, 0]
1409
+ y2 = vec2[:, 1]
1410
+ z2 = vec2[:, 2]
1411
+
1412
+ ww = np.exp(
1413
+ -pw2
1414
+ * ((nside) ** 2)
1415
+ * ((x2) ** 2 + (y2) ** 2 + (z2 - 1.0) ** 2)
1416
+ )
1417
+ idx = np.where((ww**2) > threshold)[0]
1418
+ nval2 = len(idx)
1419
+ indice2[iv2 : iv2 + nval2, 0] = iii
1420
+ indice2[iv2 : iv2 + nval2, 1] = hidx[idx]
1421
+ wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1422
+ iv2 += nval2
1423
+
1424
+ for l_rotation in range(self.NORIENT):
1425
+
1426
+ angle = (
1427
+ l_rotation / 4.0 * np.pi
1428
+ - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1429
+ - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1430
+ )
1431
+
1432
+ # posi=2*(0.5-(z[hidx]<0))
1433
+
1434
+ axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1435
+ wresr = ww * np.cos(pw * axes * (nside) * np.pi)
1436
+ wresi = ww * np.sin(pw * axes * (nside) * np.pi)
1437
+
1438
+ vnorm = wresr * wresr + wresi * wresi
1439
+ idx = np.where(vnorm > threshold)[0]
1440
+
1441
+ nval = len(idx)
1442
+ indice[iv : iv + nval, 0] = iii * 4 + l_rotation
1443
+ indice[iv : iv + nval, 1] = hidx[idx]
1444
+ # print([hidx[k] for k in idx])
1445
+ # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1446
+ normr = np.mean(wresr[idx])
1447
+ normi = np.mean(wresi[idx])
1448
+
1449
+ val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1450
+ val = val / abs(val).sum()
1451
+
1452
+ wav[iv : iv + nval] = val
1453
+ iv += nval
1454
+
1455
+ indice = indice[:iv, :]
1456
+ wav = wav[:iv]
1457
+ indice2 = indice2[:iv2, :]
1458
+ wwav = wwav[:iv2]
1177
1459
  if not self.silent:
1178
- print('Kernel Size ',iv/(self.NORIENT*12*nside*nside))
1460
+ print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1179
1461
  """
1180
1462
  # OLD VERSION OLD VERSION OLD VERSION (3.0)
1181
1463
  if self.KERNELSZ*self.KERNELSZ>12*nside*nside:
1182
1464
  l_kernel=3
1183
-
1465
+
1184
1466
  aa=np.cos(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
1185
1467
  bb=np.sin(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
1186
1468
  x,y,z=hp.pix2vec(nside,np.arange(12*nside*nside),nest=True)
@@ -1199,21 +1481,21 @@ class FoCUS:
1199
1481
  pw=np.pi/4.0
1200
1482
  pw2=1/2
1201
1483
  amp=1.0
1202
-
1484
+
1203
1485
  if l_kernel==5:
1204
1486
  pw=np.pi/4.0
1205
1487
  pw2=1/2.25
1206
1488
  amp=1.0/9.2038
1207
-
1489
+
1208
1490
  elif l_kernel==3:
1209
1491
  pw=1.0/np.sqrt(2)
1210
1492
  pw2=1.0
1211
1493
  amp=1/8.45
1212
-
1494
+
1213
1495
  elif l_kernel==7:
1214
1496
  pw=np.pi/4.0
1215
1497
  pw2=1.0/3.0
1216
-
1498
+
1217
1499
  for k in range(12*nside*nside):
1218
1500
  if k%(nside*nside)==0:
1219
1501
  if not self.silent:
@@ -1223,12 +1505,12 @@ class FoCUS:
1223
1505
  lidx=np.concatenate([lidx,np.array([(k//(scale*scale))])],0)
1224
1506
  lidx=np.repeat(lidx*(scale*scale),(scale*scale))+ \
1225
1507
  np.tile(np.arange((scale*scale)),lidx.shape[0])
1226
-
1508
+
1227
1509
  delta=(x[lidx]-x[k])**2+(y[lidx]-y[k])**2+(z[lidx]-z[k])**2
1228
1510
  pidx=np.where(delta<(10)/(nside**2))[0]
1229
1511
  if len(pidx)<l_kernel**2:
1230
1512
  pidx=np.arange(delta.shape[0])
1231
-
1513
+
1232
1514
  w=np.exp(-pw2*delta[pidx]*(nside**2))
1233
1515
  pidx=pidx[np.argsort(-w)[0:l_kernel**2]]
1234
1516
  pidx=pidx[np.argsort(lidx[pidx])]
@@ -1240,16 +1522,16 @@ class FoCUS:
1240
1522
  r=hp.Rotator(rot=rot)
1241
1523
  ty,tx=r(to[iwav[k]],po[iwav[k]])
1242
1524
  ty=ty-np.pi/2
1243
-
1525
+
1244
1526
  xx=np.expand_dims(pw*nside*np.pi*tx/np.cos(ty),-1)
1245
1527
  yy=np.expand_dims(pw*nside*np.pi*ty,-1)
1246
-
1528
+
1247
1529
  wav[k,:,:]=(np.cos(xx*aa+yy*bb)+complex(0.0,1.0)*np.sin(xx*aa+yy*bb))*np.expand_dims(w,-1)
1248
-
1530
+
1249
1531
  wav=wav-np.expand_dims(np.mean(wav,1),1)
1250
1532
  wav=amp*wav/np.expand_dims(np.std(wav,1),1)
1251
1533
  wwav=wwav/np.expand_dims(np.sum(wwav,1),1)
1252
-
1534
+
1253
1535
  nk=l_kernel*l_kernel
1254
1536
  indice=np.zeros([12*nside*nside*nk*self.NORIENT,2],dtype='int')
1255
1537
  lidx=np.arange(self.NORIENT)
@@ -1261,7 +1543,7 @@ class FoCUS:
1261
1543
  for i in range(12*nside*nside):
1262
1544
  indice2[i*nk:i*nk+nk,0]=i
1263
1545
  indice2[i*nk:i*nk+nk,1]=iwav[i]
1264
-
1546
+
1265
1547
  w=np.zeros([12*nside*nside,wav.shape[2],wav.shape[1]],dtype='complex')
1266
1548
  for i in range(wav.shape[1]):
1267
1549
  for j in range(wav.shape[2]):
@@ -1270,721 +1552,1155 @@ class FoCUS:
1270
1552
  wwav=wwav.flatten()
1271
1553
  """
1272
1554
  if not self.silent:
1273
- print('Write FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
1274
- np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice)
1275
- np.save('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),wav)
1276
- np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice2)
1277
- np.save('%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),wwav)
1555
+ print(
1556
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1557
+ % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1558
+ )
1559
+ np.save(
1560
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1561
+ % (
1562
+ self.TEMPLATE_PATH,
1563
+ TMPFILE_VERSION,
1564
+ self.KERNELSZ**2,
1565
+ self.NORIENT,
1566
+ nside,
1567
+ ),
1568
+ indice,
1569
+ )
1570
+ np.save(
1571
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1572
+ % (
1573
+ self.TEMPLATE_PATH,
1574
+ TMPFILE_VERSION,
1575
+ self.KERNELSZ**2,
1576
+ self.NORIENT,
1577
+ nside,
1578
+ ),
1579
+ wav,
1580
+ )
1581
+ np.save(
1582
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1583
+ % (
1584
+ self.TEMPLATE_PATH,
1585
+ TMPFILE_VERSION,
1586
+ self.KERNELSZ**2,
1587
+ self.NORIENT,
1588
+ nside,
1589
+ ),
1590
+ indice2,
1591
+ )
1592
+ np.save(
1593
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1594
+ % (
1595
+ self.TEMPLATE_PATH,
1596
+ TMPFILE_VERSION,
1597
+ self.KERNELSZ**2,
1598
+ self.NORIENT,
1599
+ nside,
1600
+ ),
1601
+ wwav,
1602
+ )
1278
1603
  else:
1279
- if l_kernel**2==9:
1280
- if self.rank==0:
1604
+ if l_kernel**2 == 9:
1605
+ if self.rank == 0:
1281
1606
  self.comp_idx_w9(nside)
1282
- elif l_kernel**2==25:
1283
- if self.rank==0:
1607
+ elif l_kernel**2 == 25:
1608
+ if self.rank == 0:
1284
1609
  self.comp_idx_w25(nside)
1285
1610
  else:
1286
- if self.rank==0:
1611
+ if self.rank == 0:
1287
1612
  if not self.silent:
1288
- print('Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d'%(KERNELSZ,KERNELSZ))
1613
+ print(
1614
+ "Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d"
1615
+ % (self.KERNELSZ, self.KERNELSZ)
1616
+ )
1289
1617
  return None
1290
1618
 
1291
- self.barrier()
1292
- if self.use_2D:
1293
- tmp=np.load('%s/W%d_%s_%d_IDX.npy'%(self.TEMPLATE_PATH,l_kernel**2,TMPFILE_VERSION,nside))
1619
+ self.barrier()
1620
+ if self.use_2D:
1621
+ tmp = np.load(
1622
+ "%s/W%d_%s_%d_IDX.npy"
1623
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1624
+ )
1294
1625
  else:
1295
- tmp=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
1296
- tmp2=np.load('%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
1297
- wr=np.load('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside)).real
1298
- wi=np.load('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside)).imag
1299
- ws=self.slope*np.load('%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
1300
-
1301
- wr=self.backend.bk_SparseTensor(self.backend.constant(tmp),self.backend.constant(self.backend.bk_cast(wr)),dense_shape=[12*nside**2*self.NORIENT,12*nside**2])
1302
- wi=self.backend.bk_SparseTensor(self.backend.constant(tmp),self.backend.constant(self.backend.bk_cast(wi)),dense_shape=[12*nside**2*self.NORIENT,12*nside**2])
1303
- ws=self.backend.bk_SparseTensor(self.backend.constant(tmp2),self.backend.constant(self.backend.bk_cast(ws)),dense_shape=[12*nside**2,12*nside**2])
1304
-
1305
- if kernel==-1:
1306
- self.Idx_Neighbours[nside]=tmp
1307
-
1626
+ tmp = np.load(
1627
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1628
+ % (
1629
+ self.TEMPLATE_PATH,
1630
+ TMPFILE_VERSION,
1631
+ self.KERNELSZ**2,
1632
+ self.NORIENT,
1633
+ nside,
1634
+ )
1635
+ )
1636
+ tmp2 = np.load(
1637
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1638
+ % (
1639
+ self.TEMPLATE_PATH,
1640
+ TMPFILE_VERSION,
1641
+ self.KERNELSZ**2,
1642
+ self.NORIENT,
1643
+ nside,
1644
+ )
1645
+ )
1646
+ wr = np.load(
1647
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1648
+ % (
1649
+ self.TEMPLATE_PATH,
1650
+ TMPFILE_VERSION,
1651
+ self.KERNELSZ**2,
1652
+ self.NORIENT,
1653
+ nside,
1654
+ )
1655
+ ).real
1656
+ wi = np.load(
1657
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1658
+ % (
1659
+ self.TEMPLATE_PATH,
1660
+ TMPFILE_VERSION,
1661
+ self.KERNELSZ**2,
1662
+ self.NORIENT,
1663
+ nside,
1664
+ )
1665
+ ).imag
1666
+ ws = self.slope * np.load(
1667
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1668
+ % (
1669
+ self.TEMPLATE_PATH,
1670
+ TMPFILE_VERSION,
1671
+ self.KERNELSZ**2,
1672
+ self.NORIENT,
1673
+ nside,
1674
+ )
1675
+ )
1676
+
1677
+ wr = self.backend.bk_SparseTensor(
1678
+ self.backend.constant(tmp),
1679
+ self.backend.constant(self.backend.bk_cast(wr)),
1680
+ dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1681
+ )
1682
+ wi = self.backend.bk_SparseTensor(
1683
+ self.backend.constant(tmp),
1684
+ self.backend.constant(self.backend.bk_cast(wi)),
1685
+ dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1686
+ )
1687
+ ws = self.backend.bk_SparseTensor(
1688
+ self.backend.constant(tmp2),
1689
+ self.backend.constant(self.backend.bk_cast(ws)),
1690
+ dense_shape=[12 * nside**2, 12 * nside**2],
1691
+ )
1692
+
1693
+ if kernel == -1:
1694
+ self.Idx_Neighbours[nside] = tmp
1695
+
1308
1696
  if self.use_2D:
1309
- if kernel!=-1:
1697
+ if kernel != -1:
1310
1698
  return tmp
1311
-
1312
- return wr,wi,ws,tmp
1313
1699
 
1314
-
1315
- # ---------------------------------------------−---------
1316
- # Compute x [....,a,....] to [....,a*a,....]
1317
- #NOT YET TESTED OR IMPLEMENTED
1318
- def auto_cross_2(x,axis=0):
1319
- shape=np.array(x.shape)
1320
- if axis==0:
1321
- y1=self.reshape(x,[shape[0],1,np.cumprod(shape[1:])])
1322
- y2=self.reshape(x,[1,shape[0],np.cumprod(shape[1:])])
1323
- oshape=np.concat([shape[0],shape[0],shape[1:]])
1324
- return(self.reshape(y1*y2,oshape))
1325
-
1326
- # ---------------------------------------------−---------
1327
- # Compute x [....,a,....,b,....] to [....,b*b,....,a*a,....]
1328
- #NOT YET TESTED OR IMPLEMENTED
1329
- def auto_cross_2(x,axis1=0,axis2=1):
1330
- shape=np.array(x.shape)
1331
- if axis==0:
1332
- y1=self.reshape(x,[shape[0],1,np.cumprod(shape[1:])])
1333
- y2=self.reshape(x,[1,shape[0],np.cumprod(shape[1:])])
1334
- oshape=np.concat([shape[0],shape[0],shape[1:]])
1335
- return(self.reshape(y1*y2,oshape))
1336
-
1337
-
1700
+ return wr, wi, ws, tmp
1701
+
1338
1702
  # ---------------------------------------------−---------
1339
1703
  # convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
1340
- def swapaxes(self,x,axis1,axis2):
1341
- shape=list(x.shape)
1342
- if axis1<0:
1343
- laxis1=len(shape)+axis1
1704
+ def swapaxes(self, x, axis1, axis2):
1705
+ shape = list(x.shape)
1706
+ if axis1 < 0:
1707
+ laxis1 = len(shape) + axis1
1344
1708
  else:
1345
- laxis1=axis1
1346
- if axis2<0:
1347
- laxis2=len(shape)+axis2
1709
+ laxis1 = axis1
1710
+ if axis2 < 0:
1711
+ laxis2 = len(shape) + axis2
1348
1712
  else:
1349
- laxis2=axis2
1350
-
1351
- naxes=len(shape)
1352
- thelist=[i for i in range(naxes)]
1353
- thelist[laxis1]=laxis2
1354
- thelist[laxis2]=laxis1
1355
- return self.backend.bk_transpose(x,thelist)
1356
-
1713
+ laxis2 = axis2
1714
+
1715
+ naxes = len(shape)
1716
+ thelist = [i for i in range(naxes)]
1717
+ thelist[laxis1] = laxis2
1718
+ thelist[laxis2] = laxis1
1719
+ return self.backend.bk_transpose(x, thelist)
1720
+
1357
1721
  # ---------------------------------------------−---------
1358
1722
  # Mean using mask x [....,Npix,....], mask[Nmask,Npix] to [....,Nmask,....]
1359
1723
  # if use_2D
1360
1724
  # Mean using mask x [....,12,Nside+2*off,Nside+2*off,....], mask[Nmask,12,Nside+2*off,Nside+2*off] to [....,Nmask,....]
1361
- def masked_mean(self,x,mask,axis=0,rank=0,calc_var=False):
1362
-
1363
- #==========================================================================
1725
+ def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
1726
+
1727
+ # ==========================================================================
1364
1728
  # in input data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]]
1365
1729
  # in input mask=[Nmask,X[,Y]]
1366
1730
  # if self.use_2D : X[,Y]] = [X,Y]
1367
1731
  # if second level: NORIENT[,NORIENT]= NORIENT,NORIENT
1368
- #==========================================================================
1369
-
1370
- shape=list(x.shape)
1371
-
1732
+ # ==========================================================================
1733
+
1734
+ shape = list(x.shape)
1735
+
1372
1736
  if not self.use_2D:
1373
- nside=int(np.sqrt(x.shape[axis]//12))
1374
-
1375
- l_mask=mask
1737
+ nside = int(np.sqrt(x.shape[axis] // 12))
1738
+
1739
+ l_mask = mask
1376
1740
  if self.mask_norm:
1377
- sum_mask=self.backend.bk_reduce_sum(self.backend.bk_reshape(l_mask,[l_mask.shape[0],np.prod(np.array(l_mask.shape[1:]))]),1)
1741
+ sum_mask = self.backend.bk_reduce_sum(
1742
+ self.backend.bk_reshape(
1743
+ l_mask, [l_mask.shape[0], np.prod(np.array(l_mask.shape[1:]))]
1744
+ ),
1745
+ 1,
1746
+ )
1378
1747
  if not self.use_2D:
1379
- l_mask=12*nside*nside*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1380
- elif self.use_2D:
1381
- l_mask=mask.shape[1]*mask.shape[2]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1748
+ l_mask = (
1749
+ 12
1750
+ * nside
1751
+ * nside
1752
+ * l_mask
1753
+ / self.backend.bk_reshape(
1754
+ sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]]
1755
+ )
1756
+ )
1757
+ elif self.use_2D:
1758
+ l_mask = (
1759
+ mask.shape[1]
1760
+ * mask.shape[2]
1761
+ * l_mask
1762
+ / self.backend.bk_reshape(
1763
+ sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]]
1764
+ )
1765
+ )
1382
1766
  else:
1383
- l_mask=mask.shape[1]*l_mask/self.backend.bk_reshape(sum_mask,[l_mask.shape[0]]+[1 for i in l_mask.shape[1:]])
1767
+ l_mask = (
1768
+ mask.shape[1]
1769
+ * l_mask
1770
+ / self.backend.bk_reshape(
1771
+ sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]]
1772
+ )
1773
+ )
1384
1774
 
1385
1775
  if self.use_2D:
1386
- if self.padding=='VALID':
1387
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1388
- if shape[axis]!=l_mask.shape[1]:
1389
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1390
-
1391
- ichannel=1
1776
+ if self.padding == "VALID":
1777
+ l_mask = l_mask[
1778
+ :,
1779
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1780
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1781
+ ]
1782
+ if shape[axis] != l_mask.shape[1]:
1783
+ l_mask = l_mask[
1784
+ :,
1785
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1786
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1787
+ ]
1788
+
1789
+ ichannel = 1
1392
1790
  for i in range(axis):
1393
- ichannel*=shape[i]
1394
- ochannel=1
1395
- for i in range(axis+2,len(shape)):
1396
- ochannel*=shape[i]
1397
- l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],shape[axis+1],ochannel])
1398
-
1399
- if self.padding=='VALID':
1400
- oshape=[k for k in shape]
1401
- oshape[axis]=oshape[axis]-self.KERNELSZ+1
1402
- oshape[axis+1]=oshape[axis+1]-self.KERNELSZ+1
1403
- l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1404
-
1791
+ ichannel *= shape[i]
1792
+ ochannel = 1
1793
+ for i in range(axis + 2, len(shape)):
1794
+ ochannel *= shape[i]
1795
+ l_x = self.backend.bk_reshape(
1796
+ x, [ichannel, 1, shape[axis], shape[axis + 1], ochannel]
1797
+ )
1798
+
1799
+ if self.padding == "VALID":
1800
+ oshape = [k for k in shape]
1801
+ oshape[axis] = oshape[axis] - self.KERNELSZ + 1
1802
+ oshape[axis + 1] = oshape[axis + 1] - self.KERNELSZ + 1
1803
+ l_x = self.backend.bk_reshape(
1804
+ l_x[
1805
+ :,
1806
+ :,
1807
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1808
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1809
+ :,
1810
+ ],
1811
+ oshape,
1812
+ )
1813
+
1405
1814
  elif self.use_1D:
1406
- if self.padding=='VALID':
1407
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1408
- if shape[axis]!=l_mask.shape[1]:
1409
- l_mask=l_mask[:,self.KERNELSZ//2:-self.KERNELSZ//2+1]
1410
-
1411
- ichannel=1
1815
+ if self.padding == "VALID":
1816
+ l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1]
1817
+ if shape[axis] != l_mask.shape[1]:
1818
+ l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1]
1819
+
1820
+ ichannel = 1
1412
1821
  for i in range(axis):
1413
- ichannel*=shape[i]
1414
- ochannel=1
1415
- for i in range(axis+1,len(shape)):
1416
- ochannel*=shape[i]
1417
- l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1418
-
1419
- if self.padding=='VALID':
1420
- oshape=[k for k in shape]
1421
- oshape[axis]=oshape[axis]-self.KERNELSZ+1
1422
- l_x=self.backend.bk_reshape(l_x[:,:,self.KERNELSZ//2:-self.KERNELSZ//2+1,:],oshape)
1822
+ ichannel *= shape[i]
1823
+ ochannel = 1
1824
+ for i in range(axis + 1, len(shape)):
1825
+ ochannel *= shape[i]
1826
+ l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1827
+
1828
+ if self.padding == "VALID":
1829
+ oshape = [k for k in shape]
1830
+ oshape[axis] = oshape[axis] - self.KERNELSZ + 1
1831
+ l_x = self.backend.bk_reshape(
1832
+ l_x[:, :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, :], oshape
1833
+ )
1423
1834
  else:
1424
- ichannel=1
1835
+ ichannel = 1
1425
1836
  for i in range(axis):
1426
- ichannel*=shape[i]
1427
- ochannel=1
1428
- for i in range(axis+1,len(shape)):
1429
- ochannel*=shape[i]
1430
- l_x=self.backend.bk_reshape(x,[ichannel,1,shape[axis],ochannel])
1837
+ ichannel *= shape[i]
1838
+ ochannel = 1
1839
+ for i in range(axis + 1, len(shape)):
1840
+ ochannel *= shape[i]
1841
+ l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1431
1842
 
1432
1843
  # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,1,...,X[,Y],NORIENT[,NORIENT]]
1433
1844
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,X[,Y]]
1434
- l_mask=self.backend.bk_expand_dims(l_mask,0)
1845
+ l_mask = self.backend.bk_expand_dims(l_mask, 0)
1435
1846
  # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,X[,Y],1]
1436
- l_mask=self.backend.bk_expand_dims(l_mask,-1)
1437
-
1438
- if l_x.dtype==self.all_cbk_type:
1439
- l_mask=self.backend.bk_complex(l_mask,self.backend.bk_cast(0.0*l_mask))
1440
-
1847
+ l_mask = self.backend.bk_expand_dims(l_mask, -1)
1848
+
1849
+ if l_x.dtype == self.all_cbk_type:
1850
+ l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
1851
+
1441
1852
  if self.use_2D:
1442
- mtmp=l_mask
1443
- vtmp=l_x
1444
-
1445
- v1=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp,axis=2),2)
1446
- v2=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2),2)
1447
- vh=self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp,axis=2),2)
1448
-
1449
- res=v1/vh
1450
-
1451
- oshape=[]
1452
- if axis>0:
1453
- oshape=oshape+list(x.shape[0:axis])
1454
- oshape=oshape+[mask.shape[0]]
1455
- if axis+1<len(x.shape):
1456
- oshape=oshape+list(x.shape[axis+2:])
1457
-
1853
+ mtmp = l_mask
1854
+ vtmp = l_x
1855
+
1856
+ v1 = self.backend.bk_reduce_sum(
1857
+ self.backend.bk_reduce_sum(mtmp * vtmp, axis=2), 2
1858
+ )
1859
+ v2 = self.backend.bk_reduce_sum(
1860
+ self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2), 2
1861
+ )
1862
+ vh = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp, axis=2), 2)
1863
+
1864
+ res = v1 / vh
1865
+
1866
+ oshape = []
1867
+ if axis > 0:
1868
+ oshape = oshape + list(x.shape[0:axis])
1869
+ oshape = oshape + [mask.shape[0]]
1870
+ if axis + 1 < len(x.shape):
1871
+ oshape = oshape + list(x.shape[axis + 2 :])
1872
+
1458
1873
  if calc_var:
1459
1874
  if self.backend.bk_is_complex(vtmp):
1460
- res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1461
- -self.backend.bk_real(res)*self.backend.bk_real(res)) + \
1462
- (self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
1463
- -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1875
+ res2 = self.backend.bk_sqrt(
1876
+ (
1877
+ (
1878
+ self.backend.bk_real(v2) / self.backend.bk_real(vh)
1879
+ - self.backend.bk_real(res) * self.backend.bk_real(res)
1880
+ )
1881
+ + (
1882
+ self.backend.bk_imag(v2) / self.backend.bk_real(vh)
1883
+ - self.backend.bk_imag(res) * self.backend.bk_imag(res)
1884
+ )
1885
+ )
1886
+ / self.backend.bk_real(vh)
1887
+ )
1464
1888
  else:
1465
- res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1889
+ res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1466
1890
 
1467
- res=self.backend.bk_reshape(res,oshape)
1468
- res2=self.backend.bk_reshape(res2,oshape)
1469
- return res,res2
1891
+ res = self.backend.bk_reshape(res, oshape)
1892
+ res2 = self.backend.bk_reshape(res2, oshape)
1893
+ return res, res2
1470
1894
  else:
1471
- res=self.backend.bk_reshape(res,oshape)
1895
+ res = self.backend.bk_reshape(res, oshape)
1472
1896
  return res
1473
-
1897
+
1474
1898
  elif self.use_1D:
1475
- mtmp=l_mask
1476
- vtmp=l_x
1477
-
1478
- v1=self.backend.bk_reduce_sum(mtmp*vtmp,axis=2)
1479
- v2=self.backend.bk_reduce_sum(mtmp*vtmp*vtmp,axis=2)
1480
- vh=self.backend.bk_reduce_sum(mtmp,axis=2)
1481
-
1482
- res=v1/vh
1483
-
1484
- oshape=[]
1485
- if axis>0:
1486
- oshape=oshape+list(x.shape[0:axis])
1487
- oshape=oshape+[mask.shape[0]]
1488
- if axis+1<len(x.shape):
1489
- oshape=oshape+list(x.shape[axis+1:])
1490
-
1899
+ mtmp = l_mask
1900
+ vtmp = l_x
1901
+
1902
+ v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=2)
1903
+ v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2)
1904
+ vh = self.backend.bk_reduce_sum(mtmp, axis=2)
1905
+
1906
+ res = v1 / vh
1907
+
1908
+ oshape = []
1909
+ if axis > 0:
1910
+ oshape = oshape + list(x.shape[0:axis])
1911
+ oshape = oshape + [mask.shape[0]]
1912
+ if axis + 1 < len(x.shape):
1913
+ oshape = oshape + list(x.shape[axis + 1 :])
1914
+
1491
1915
  if calc_var:
1492
1916
  if self.backend.bk_is_complex(vtmp):
1493
- res2=self.backend.bk_sqrt(((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1494
- -self.backend.bk_real(res)*self.backend.bk_real(res)) + \
1495
- (self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
1496
- -self.backend.bk_imag(res)*self.backend.bk_imag(res)))/self.backend.bk_real(vh))
1917
+ res2 = self.backend.bk_sqrt(
1918
+ (
1919
+ (
1920
+ self.backend.bk_real(v2) / self.backend.bk_real(vh)
1921
+ - self.backend.bk_real(res) * self.backend.bk_real(res)
1922
+ )
1923
+ + (
1924
+ self.backend.bk_imag(v2) / self.backend.bk_real(vh)
1925
+ - self.backend.bk_imag(res) * self.backend.bk_imag(res)
1926
+ )
1927
+ )
1928
+ / self.backend.bk_real(vh)
1929
+ )
1497
1930
  else:
1498
- res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1931
+ res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1499
1932
 
1500
-
1501
- res=self.backend.bk_reshape(res,oshape)
1502
- res2=self.backend.bk_reshape(res2,oshape)
1503
- return res,res2
1933
+ res = self.backend.bk_reshape(res, oshape)
1934
+ res2 = self.backend.bk_reshape(res2, oshape)
1935
+ return res, res2
1504
1936
  else:
1505
- res=self.backend.bk_reshape(res,oshape)
1937
+ res = self.backend.bk_reshape(res, oshape)
1506
1938
  return res
1507
-
1508
- else:
1509
- v1=self.backend.bk_reduce_sum(l_mask*l_x,axis=2)
1510
- v2=self.backend.bk_reduce_sum(l_mask*l_x*l_x,axis=2)
1511
- vh=self.backend.bk_reduce_sum(l_mask,axis=2)
1512
-
1513
- res=v1/vh
1514
-
1515
- oshape=[]
1516
- if axis>0:
1517
- oshape=oshape+list(x.shape[0:axis])
1518
- oshape=oshape+[mask.shape[0]]
1519
- if axis+1<len(x.shape):
1520
- oshape=oshape+list(x.shape[axis+1:])
1521
-
1939
+
1940
+ else:
1941
+ v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=2)
1942
+ v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=2)
1943
+ vh = self.backend.bk_reduce_sum(l_mask, axis=2)
1944
+
1945
+ res = v1 / vh
1946
+
1947
+ oshape = []
1948
+ if axis > 0:
1949
+ oshape = oshape + list(x.shape[0:axis])
1950
+ oshape = oshape + [mask.shape[0]]
1951
+ if axis + 1 < len(x.shape):
1952
+ oshape = oshape + list(x.shape[axis + 1 :])
1953
+
1522
1954
  if calc_var:
1523
1955
  if self.backend.bk_is_complex(l_x):
1524
- res2=self.backend.bk_sqrt((self.backend.bk_real(v2)/self.backend.bk_real(vh)
1525
- -self.backend.bk_real(res)*self.backend.bk_real(res) + \
1526
- self.backend.bk_imag(v2)/self.backend.bk_real(vh) \
1527
- -self.backend.bk_imag(res)*self.backend.bk_imag(res))/self.backend.bk_real(vh))
1956
+ res2 = self.backend.bk_sqrt(
1957
+ (
1958
+ self.backend.bk_real(v2) / self.backend.bk_real(vh)
1959
+ - self.backend.bk_real(res) * self.backend.bk_real(res)
1960
+ + self.backend.bk_imag(v2) / self.backend.bk_real(vh)
1961
+ - self.backend.bk_imag(res) * self.backend.bk_imag(res)
1962
+ )
1963
+ / self.backend.bk_real(vh)
1964
+ )
1528
1965
  else:
1529
- res2=self.backend.bk_sqrt((v2/vh-res*res)/(vh))
1530
-
1531
- res=self.backend.bk_reshape(res,oshape)
1532
- res2=self.backend.bk_reshape(res2,oshape)
1533
- return res,res2
1966
+ res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1967
+
1968
+ res = self.backend.bk_reshape(res, oshape)
1969
+ res2 = self.backend.bk_reshape(res2, oshape)
1970
+ return res, res2
1534
1971
  else:
1535
- res=self.backend.bk_reshape(res,oshape)
1972
+ res = self.backend.bk_reshape(res, oshape)
1536
1973
  return res
1537
-
1974
+
1538
1975
  # ---------------------------------------------−---------
1539
1976
  # convert tensor x [....,a,b,....] to [....,a*b,....]
1540
- def reduce_dim(self,x,axis=0):
1541
- shape=list(x.shape)
1542
-
1543
- if axis<0:
1544
- laxis=len(shape)+axis
1977
+ def reduce_dim(self, x, axis=0):
1978
+ shape = list(x.shape)
1979
+
1980
+ if axis < 0:
1981
+ laxis = len(shape) + axis
1545
1982
  else:
1546
- laxis=axis
1547
-
1548
- if laxis>0 :
1549
- oshape=shape[0:laxis]
1550
- oshape.append(shape[laxis]*shape[laxis+1])
1983
+ laxis = axis
1984
+
1985
+ if laxis > 0:
1986
+ oshape = shape[0:laxis]
1987
+ oshape.append(shape[laxis] * shape[laxis + 1])
1551
1988
  else:
1552
- oshape=[shape[laxis]*shape[laxis+1]]
1553
-
1554
- if laxis<len(shape)-1:
1555
- oshape.extend(shape[laxis+2:])
1556
-
1557
- return(self.backend.bk_reshape(x,oshape))
1558
-
1559
-
1989
+ oshape = [shape[laxis] * shape[laxis + 1]]
1990
+
1991
+ if laxis < len(shape) - 1:
1992
+ oshape.extend(shape[laxis + 2 :])
1993
+
1994
+ return self.backend.bk_reshape(x, oshape)
1995
+
1560
1996
  # ---------------------------------------------−---------
1561
- def conv2d(self,image,ww,axis=0):
1997
+ def conv2d(self, image, ww, axis=0):
1562
1998
 
1563
- if len(ww.shape)==2:
1564
- norient=ww.shape[1]
1999
+ if len(ww.shape) == 2:
2000
+ norient = ww.shape[1]
1565
2001
  else:
1566
- norient=ww.shape[2]
2002
+ norient = ww.shape[2]
1567
2003
 
1568
- shape=image.shape
2004
+ shape = image.shape
1569
2005
 
1570
- if axis>0:
1571
- o_shape=shape[0]
1572
- for k in range(1,axis+1):
1573
- o_shape=o_shape*shape[k]
2006
+ if axis > 0:
2007
+ o_shape = shape[0]
2008
+ for k in range(1, axis + 1):
2009
+ o_shape = o_shape * shape[k]
1574
2010
  else:
1575
- o_shape=image.shape[0]
1576
-
1577
- if len(shape)>axis+3:
1578
- ishape=shape[axis+3]
1579
- for k in range(axis+4,len(shape)):
1580
- ishape=ishape*shape[k]
1581
-
1582
- oshape=[o_shape,shape[axis+1],shape[axis+2],ishape]
1583
-
1584
- #l_image=self.swapaxes(self.bk_reshape(image,oshape),-1,-3)
1585
- l_image=self.backend.bk_reshape(image,oshape)
1586
-
1587
- l_ww=np.zeros([self.KERNELSZ,self.KERNELSZ,ishape,ishape*norient])
2011
+ o_shape = image.shape[0]
2012
+
2013
+ if len(shape) > axis + 3:
2014
+ ishape = shape[axis + 3]
2015
+ for k in range(axis + 4, len(shape)):
2016
+ ishape = ishape * shape[k]
2017
+
2018
+ oshape = [o_shape, shape[axis + 1], shape[axis + 2], ishape]
2019
+
2020
+ # l_image=self.swapaxes(self.bk_reshape(image,oshape),-1,-3)
2021
+ l_image = self.backend.bk_reshape(image, oshape)
2022
+
2023
+ l_ww = np.zeros([self.KERNELSZ, self.KERNELSZ, ishape, ishape * norient])
1588
2024
  for k in range(ishape):
1589
- l_ww[:,:,k,k*norient:(k+1)*norient]=ww.reshape(self.KERNELSZ,self.KERNELSZ,norient)
1590
-
2025
+ l_ww[:, :, k, k * norient : (k + 1) * norient] = ww.reshape(
2026
+ self.KERNELSZ, self.KERNELSZ, norient
2027
+ )
2028
+
1591
2029
  if self.backend.bk_is_complex(l_image):
1592
- r=self.backend.conv2d(self.backend.bk_real(l_image),
1593
- l_ww,
1594
- strides=[1, 1, 1, 1],
1595
- padding=self.padding)
1596
- i=self.backend.conv2d(self.backend.bk_imag(l_image),
1597
- l_ww,
1598
- strides=[1, 1, 1, 1],
1599
- padding=self.padding)
1600
- res=self.backend.bk_complex(r,i)
2030
+ r = self.backend.conv2d(
2031
+ self.backend.bk_real(l_image),
2032
+ l_ww,
2033
+ strides=[1, 1, 1, 1],
2034
+ padding=self.padding,
2035
+ )
2036
+ i = self.backend.conv2d(
2037
+ self.backend.bk_imag(l_image),
2038
+ l_ww,
2039
+ strides=[1, 1, 1, 1],
2040
+ padding=self.padding,
2041
+ )
2042
+ res = self.backend.bk_complex(r, i)
1601
2043
  else:
1602
- res=self.backend.conv2d(l_image,l_ww,strides=[1, 1, 1, 1],padding=self.padding)
2044
+ res = self.backend.conv2d(
2045
+ l_image, l_ww, strides=[1, 1, 1, 1], padding=self.padding
2046
+ )
1603
2047
 
1604
- res=self.backend.bk_reshape(res,[o_shape,shape[axis+1],shape[axis+2],ishape,norient])
2048
+ res = self.backend.bk_reshape(
2049
+ res, [o_shape, shape[axis + 1], shape[axis + 2], ishape, norient]
2050
+ )
1605
2051
  else:
1606
- oshape=[o_shape,shape[axis+1],shape[axis+2],1]
1607
- l_ww=self.backend.bk_reshape(ww,[self.KERNELSZ,self.KERNELSZ,1,norient])
2052
+ oshape = [o_shape, shape[axis + 1], shape[axis + 2], 1]
2053
+ l_ww = self.backend.bk_reshape(
2054
+ ww, [self.KERNELSZ, self.KERNELSZ, 1, norient]
2055
+ )
1608
2056
 
1609
- tmp=self.backend.bk_reshape(image,oshape)
2057
+ tmp = self.backend.bk_reshape(image, oshape)
1610
2058
  if self.backend.bk_is_complex(tmp):
1611
- r=self.backend.conv2d(self.backend.bk_real(tmp),
1612
- l_ww,
1613
- strides=[1, 1, 1, 1],
1614
- padding=self.padding)
1615
- i=self.backend.conv2d(self.backend.bk_imag(tmp),
1616
- l_ww,
1617
- strides=[1, 1, 1, 1],
1618
- padding=self.padding)
1619
- res=self.backend.bk_complex(r,i)
2059
+ r = self.backend.conv2d(
2060
+ self.backend.bk_real(tmp),
2061
+ l_ww,
2062
+ strides=[1, 1, 1, 1],
2063
+ padding=self.padding,
2064
+ )
2065
+ i = self.backend.conv2d(
2066
+ self.backend.bk_imag(tmp),
2067
+ l_ww,
2068
+ strides=[1, 1, 1, 1],
2069
+ padding=self.padding,
2070
+ )
2071
+ res = self.backend.bk_complex(r, i)
1620
2072
  else:
1621
- res=self.backend.conv2d(tmp,
1622
- l_ww,
1623
- strides=[1, 1, 1, 1],
1624
- padding=self.padding)
2073
+ res = self.backend.conv2d(
2074
+ tmp, l_ww, strides=[1, 1, 1, 1], padding=self.padding
2075
+ )
1625
2076
 
1626
- return self.backend.bk_reshape(res,shape+[norient])
2077
+ return self.backend.bk_reshape(res, shape + [norient])
1627
2078
 
1628
- def diff_data(self,x,y,is_complex=True,sigma=None):
2079
+ def diff_data(self, x, y, is_complex=True, sigma=None):
1629
2080
  if sigma is None:
1630
2081
  if self.backend.bk_is_complex(x):
1631
- r=self.backend.bk_square(self.backend.bk_real(x)-self.backend.bk_real(y))
1632
- i=self.backend.bk_square(self.backend.bk_imag(x)-self.backend.bk_imag(y))
1633
- return self.backend.bk_reduce_sum(r+i)
2082
+ r = self.backend.bk_square(
2083
+ self.backend.bk_real(x) - self.backend.bk_real(y)
2084
+ )
2085
+ i = self.backend.bk_square(
2086
+ self.backend.bk_imag(x) - self.backend.bk_imag(y)
2087
+ )
2088
+ return self.backend.bk_reduce_sum(r + i)
1634
2089
  else:
1635
- r=self.backend.bk_square(x-y)
2090
+ r = self.backend.bk_square(x - y)
1636
2091
  return self.backend.bk_reduce_sum(r)
1637
2092
  else:
1638
2093
  if self.backend.bk_is_complex(x):
1639
- r=self.backend.bk_square((self.backend.bk_real(x)-self.backend.bk_real(y))/sigma)
1640
- i=self.backend.bk_square((self.backend.bk_imag(x)-self.backend.bk_imag(y))/sigma)
1641
- return self.backend.bk_reduce_sum(r+i)
2094
+ r = self.backend.bk_square(
2095
+ (self.backend.bk_real(x) - self.backend.bk_real(y)) / sigma
2096
+ )
2097
+ i = self.backend.bk_square(
2098
+ (self.backend.bk_imag(x) - self.backend.bk_imag(y)) / sigma
2099
+ )
2100
+ return self.backend.bk_reduce_sum(r + i)
1642
2101
  else:
1643
- r=self.backend.bk_square((x-y)/sigma)
2102
+ r = self.backend.bk_square((x - y) / sigma)
1644
2103
  return self.backend.bk_reduce_sum(r)
1645
-
2104
+
1646
2105
  # ---------------------------------------------−---------
1647
- def convol(self,in_image,axis=0):
2106
+ def convol(self, in_image, axis=0):
2107
+
2108
+ image = self.backend.bk_cast(in_image)
1648
2109
 
1649
- image=self.backend.bk_cast(in_image)
1650
-
1651
2110
  if self.use_2D:
1652
- ishape=list(in_image.shape)
1653
- if len(ishape)<axis+2:
2111
+ ishape = list(in_image.shape)
2112
+ if len(ishape) < axis + 2:
1654
2113
  if not self.silent:
1655
- print('Use of 2D scat with data that has less than 2D')
2114
+ print("Use of 2D scat with data that has less than 2D")
1656
2115
  return None
1657
-
1658
- npix=ishape[axis]
1659
- npiy=ishape[axis+1]
1660
- odata=1
1661
- if len(ishape)>axis+2:
1662
- for k in range(axis+2,len(ishape)):
1663
- odata=odata*ishape[k]
1664
-
1665
- ndata=1
2116
+
2117
+ npix = ishape[axis]
2118
+ npiy = ishape[axis + 1]
2119
+ odata = 1
2120
+ if len(ishape) > axis + 2:
2121
+ for k in range(axis + 2, len(ishape)):
2122
+ odata = odata * ishape[k]
2123
+
2124
+ ndata = 1
1666
2125
  for k in range(axis):
1667
- ndata=ndata*ishape[k]
2126
+ ndata = ndata * ishape[k]
1668
2127
 
1669
- tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,npiy,odata])
2128
+ tim = self.backend.bk_reshape(
2129
+ self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2130
+ )
1670
2131
 
1671
2132
  if self.backend.bk_is_complex(tim):
1672
- rr1=self.backend.conv2d(self.backend.bk_real(tim),self.ww_RealT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1673
- ii1=self.backend.conv2d(self.backend.bk_real(tim),self.ww_ImagT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1674
- rr2=self.backend.conv2d(self.backend.bk_imag(tim),self.ww_RealT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1675
- ii2=self.backend.conv2d(self.backend.bk_imag(tim),self.ww_ImagT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1676
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2133
+ rr1 = self.backend.conv2d(
2134
+ self.backend.bk_real(tim),
2135
+ self.ww_RealT[odata],
2136
+ strides=[1, 1, 1, 1],
2137
+ padding=self.padding,
2138
+ )
2139
+ ii1 = self.backend.conv2d(
2140
+ self.backend.bk_real(tim),
2141
+ self.ww_ImagT[odata],
2142
+ strides=[1, 1, 1, 1],
2143
+ padding=self.padding,
2144
+ )
2145
+ rr2 = self.backend.conv2d(
2146
+ self.backend.bk_imag(tim),
2147
+ self.ww_RealT[odata],
2148
+ strides=[1, 1, 1, 1],
2149
+ padding=self.padding,
2150
+ )
2151
+ ii2 = self.backend.conv2d(
2152
+ self.backend.bk_imag(tim),
2153
+ self.ww_ImagT[odata],
2154
+ strides=[1, 1, 1, 1],
2155
+ padding=self.padding,
2156
+ )
2157
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1677
2158
  else:
1678
- rr=self.backend.conv2d(tim,self.ww_RealT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1679
- ii=self.backend.conv2d(tim,self.ww_ImagT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1680
- res=self.backend.bk_complex(rr,ii)
1681
-
1682
- if axis==0:
1683
- if len(ishape)==2:
1684
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2],self.NORIENT])
2159
+ rr = self.backend.conv2d(
2160
+ tim,
2161
+ self.ww_RealT[odata],
2162
+ strides=[1, 1, 1, 1],
2163
+ padding=self.padding,
2164
+ )
2165
+ ii = self.backend.conv2d(
2166
+ tim,
2167
+ self.ww_ImagT[odata],
2168
+ strides=[1, 1, 1, 1],
2169
+ padding=self.padding,
2170
+ )
2171
+ res = self.backend.bk_complex(rr, ii)
2172
+
2173
+ if axis == 0:
2174
+ if len(ishape) == 2:
2175
+ return self.backend.bk_reshape(
2176
+ res, [res.shape[1], res.shape[2], self.NORIENT]
2177
+ )
1685
2178
  else:
1686
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
2179
+ return self.backend.bk_reshape(
2180
+ res,
2181
+ [res.shape[1], res.shape[2], self.NORIENT] + ishape[axis + 2 :],
2182
+ )
1687
2183
  else:
1688
- if len(ishape)==axis+2:
1689
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT])
2184
+ if len(ishape) == axis + 2:
2185
+ return self.backend.bk_reshape(
2186
+ res, ishape[0:axis] + [res.shape[1], res.shape[2], self.NORIENT]
2187
+ )
1690
2188
  else:
1691
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2],self.NORIENT]+ishape[axis+2:])
1692
-
1693
- return self.backend.bk_reshape(res,[nout,nouty])
1694
- elif self.use_1D==True:
1695
- ishape=list(in_image.shape)
1696
- if len(ishape)<axis+1:
2189
+ return self.backend.bk_reshape(
2190
+ res,
2191
+ ishape[0:axis]
2192
+ + [res.shape[1], res.shape[2], self.NORIENT]
2193
+ + ishape[axis + 2 :],
2194
+ )
2195
+
2196
+ return self.backend.bk_reshape(res, in_image.shape+[self.NORIENT])
2197
+ elif self.use_1D:
2198
+ ishape = list(in_image.shape)
2199
+ if len(ishape) < axis + 1:
1697
2200
  if not self.silent:
1698
- print('Use of 1D scat with data that has less than 1D')
2201
+ print("Use of 1D scat with data that has less than 1D")
1699
2202
  return None
1700
-
1701
- npix=ishape[axis]
1702
- odata=1
1703
- if len(ishape)>axis+1:
1704
- for k in range(axis+1,len(ishape)):
1705
- odata=odata*ishape[k]
1706
-
1707
- ndata=1
2203
+
2204
+ npix = ishape[axis]
2205
+ odata = 1
2206
+ if len(ishape) > axis + 1:
2207
+ for k in range(axis + 1, len(ishape)):
2208
+ odata = odata * ishape[k]
2209
+
2210
+ ndata = 1
1708
2211
  for k in range(axis):
1709
- ndata=ndata*ishape[k]
2212
+ ndata = ndata * ishape[k]
1710
2213
 
1711
- tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
2214
+ tim = self.backend.bk_reshape(
2215
+ self.backend.bk_cast(in_image), [ndata, npix, odata]
2216
+ )
1712
2217
 
1713
2218
  if self.backend.bk_is_complex(tim):
1714
- rr1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1715
- ii1=self.backend.conv1d(self.backend.bk_real(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1716
- rr2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1717
- ii2=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1718
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2219
+ rr1 = self.backend.conv1d(
2220
+ self.backend.bk_real(tim),
2221
+ self.ww_RealT[odata],
2222
+ strides=[1, 1, 1],
2223
+ padding=self.padding,
2224
+ )
2225
+ ii1 = self.backend.conv1d(
2226
+ self.backend.bk_real(tim),
2227
+ self.ww_ImagT[odata],
2228
+ strides=[1, 1, 1],
2229
+ padding=self.padding,
2230
+ )
2231
+ rr2 = self.backend.conv1d(
2232
+ self.backend.bk_imag(tim),
2233
+ self.ww_RealT[odata],
2234
+ strides=[1, 1, 1],
2235
+ padding=self.padding,
2236
+ )
2237
+ ii2 = self.backend.conv1d(
2238
+ self.backend.bk_imag(tim),
2239
+ self.ww_ImagT[odata],
2240
+ strides=[1, 1, 1],
2241
+ padding=self.padding,
2242
+ )
2243
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1719
2244
  else:
1720
- rr=self.backend.conv1d(tim,self.ww_RealT[odata],strides=[1, 1, 1],padding=self.padding)
1721
- ii=self.backend.conv1d(tim,self.ww_ImagT[odata],strides=[1, 1, 1],padding=self.padding)
1722
- res=self.backend.bk_complex(rr,ii)
1723
-
1724
- if axis==0:
1725
- if len(ishape)==1:
1726
- return self.backend.bk_reshape(res,[res.shape[1]])
2245
+ rr = self.backend.conv1d(
2246
+ tim, self.ww_RealT[odata], strides=[1, 1, 1], padding=self.padding
2247
+ )
2248
+ ii = self.backend.conv1d(
2249
+ tim, self.ww_ImagT[odata], strides=[1, 1, 1], padding=self.padding
2250
+ )
2251
+ res = self.backend.bk_complex(rr, ii)
2252
+
2253
+ if axis == 0:
2254
+ if len(ishape) == 1:
2255
+ return self.backend.bk_reshape(res, [res.shape[1]])
1727
2256
  else:
1728
- return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+2:])
2257
+ return self.backend.bk_reshape(
2258
+ res, [res.shape[1]] + ishape[axis + 2 :]
2259
+ )
1729
2260
  else:
1730
- if len(ishape)==axis+1:
1731
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
2261
+ if len(ishape) == axis + 1:
2262
+ return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
1732
2263
  else:
1733
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1734
-
1735
- return self.backend.bk_reshape(res,[nout,nouty])
1736
-
1737
-
2264
+ return self.backend.bk_reshape(
2265
+ res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2266
+ )
2267
+
2268
+ return self.backend.bk_reshape(res, in_image.shape+[self.NORIENT])
2269
+
1738
2270
  else:
1739
- nside=int(np.sqrt(image.shape[axis]//12))
2271
+ nside = int(np.sqrt(image.shape[axis] // 12))
1740
2272
 
1741
2273
  if self.Idx_Neighbours[nside] is None:
1742
2274
  if self.InitWave is None:
1743
- wr,wi,ws,widx=self.init_index(nside)
2275
+ wr, wi, ws, widx = self.init_index(nside)
1744
2276
  else:
1745
- wr,wi,ws,widx=self.InitWave(self,nside)
1746
-
1747
- self.Idx_Neighbours[nside]=1 #self.backend.constant(tmp)
1748
- self.ww_Real[nside]=wr
1749
- self.ww_Imag[nside]=wi
1750
- self.w_smooth[nside]=ws
1751
-
1752
- l_ww_real=self.ww_Real[nside]
1753
- l_ww_imag=self.ww_Imag[nside]
1754
-
1755
- ishape=list(image.shape)
1756
- odata=1
1757
- for k in range(axis+1,len(ishape)):
1758
- odata=odata*ishape[k]
1759
-
1760
- if axis>0:
1761
- ndata=1
2277
+ wr, wi, ws, widx = self.InitWave(self, nside)
2278
+
2279
+ self.Idx_Neighbours[nside] = 1 # self.backend.constant(tmp)
2280
+ self.ww_Real[nside] = wr
2281
+ self.ww_Imag[nside] = wi
2282
+ self.w_smooth[nside] = ws
2283
+
2284
+ l_ww_real = self.ww_Real[nside]
2285
+ l_ww_imag = self.ww_Imag[nside]
2286
+
2287
+ ishape = list(image.shape)
2288
+ odata = 1
2289
+ for k in range(axis + 1, len(ishape)):
2290
+ odata = odata * ishape[k]
2291
+
2292
+ if axis > 0:
2293
+ ndata = 1
1762
2294
  for k in range(axis):
1763
- ndata=ndata*ishape[k]
1764
- tim=self.backend.bk_reshape(self.backend.bk_cast(image),[ndata,12*nside**2,odata])
1765
- if tim.dtype==self.all_cbk_type:
1766
- rr1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_real(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1767
- ii1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_real(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1768
- rr2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_imag(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1769
- ii2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_imag(tim[0])),[1,12*nside**2,self.NORIENT,odata])
1770
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2295
+ ndata = ndata * ishape[k]
2296
+ tim = self.backend.bk_reshape(
2297
+ self.backend.bk_cast(image), [ndata, 12 * nside**2, odata]
2298
+ )
2299
+ if tim.dtype == self.all_cbk_type:
2300
+ rr1 = self.backend.bk_reshape(
2301
+ self.backend.bk_sparse_dense_matmul(
2302
+ l_ww_real, self.backend.bk_real(tim[0])
2303
+ ),
2304
+ [1, 12 * nside**2, self.NORIENT, odata],
2305
+ )
2306
+ ii1 = self.backend.bk_reshape(
2307
+ self.backend.bk_sparse_dense_matmul(
2308
+ l_ww_imag, self.backend.bk_real(tim[0])
2309
+ ),
2310
+ [1, 12 * nside**2, self.NORIENT, odata],
2311
+ )
2312
+ rr2 = self.backend.bk_reshape(
2313
+ self.backend.bk_sparse_dense_matmul(
2314
+ l_ww_real, self.backend.bk_imag(tim[0])
2315
+ ),
2316
+ [1, 12 * nside**2, self.NORIENT, odata],
2317
+ )
2318
+ ii2 = self.backend.bk_reshape(
2319
+ self.backend.bk_sparse_dense_matmul(
2320
+ l_ww_imag, self.backend.bk_imag(tim[0])
2321
+ ),
2322
+ [1, 12 * nside**2, self.NORIENT, odata],
2323
+ )
2324
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1771
2325
  else:
1772
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,tim[0]),[1,12*nside**2,self.NORIENT,odata])
1773
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,tim[0]),[1,12*nside**2,self.NORIENT,odata])
1774
- res=self.backend.bk_complex(rr,ii)
1775
-
1776
- for k in range(1,ndata):
1777
- if tim.dtype==self.all_cbk_type:
1778
- rr1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_real(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1779
- ii1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_real(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1780
- rr2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_imag(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1781
- ii2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_imag(tim[k])),[1,12*nside**2,self.NORIENT,odata])
1782
- res=self.backend.bk_concat([res,self.backend.bk_complex(rr1-ii2,ii1+rr2)],0)
2326
+ rr = self.backend.bk_reshape(
2327
+ self.backend.bk_sparse_dense_matmul(l_ww_real, tim[0]),
2328
+ [1, 12 * nside**2, self.NORIENT, odata],
2329
+ )
2330
+ ii = self.backend.bk_reshape(
2331
+ self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[0]),
2332
+ [1, 12 * nside**2, self.NORIENT, odata],
2333
+ )
2334
+ res = self.backend.bk_complex(rr, ii)
2335
+
2336
+ for k in range(1, ndata):
2337
+ if tim.dtype == self.all_cbk_type:
2338
+ rr1 = self.backend.bk_reshape(
2339
+ self.backend.bk_sparse_dense_matmul(
2340
+ l_ww_real, self.backend.bk_real(tim[k])
2341
+ ),
2342
+ [1, 12 * nside**2, self.NORIENT, odata],
2343
+ )
2344
+ ii1 = self.backend.bk_reshape(
2345
+ self.backend.bk_sparse_dense_matmul(
2346
+ l_ww_imag, self.backend.bk_real(tim[k])
2347
+ ),
2348
+ [1, 12 * nside**2, self.NORIENT, odata],
2349
+ )
2350
+ rr2 = self.backend.bk_reshape(
2351
+ self.backend.bk_sparse_dense_matmul(
2352
+ l_ww_real, self.backend.bk_imag(tim[k])
2353
+ ),
2354
+ [1, 12 * nside**2, self.NORIENT, odata],
2355
+ )
2356
+ ii2 = self.backend.bk_reshape(
2357
+ self.backend.bk_sparse_dense_matmul(
2358
+ l_ww_imag, self.backend.bk_imag(tim[k])
2359
+ ),
2360
+ [1, 12 * nside**2, self.NORIENT, odata],
2361
+ )
2362
+ res = self.backend.bk_concat(
2363
+ [res, self.backend.bk_complex(rr1 - ii2, ii1 + rr2)], 0
2364
+ )
1783
2365
  else:
1784
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,tim[k]),[1,12*nside**2,self.NORIENT,odata])
1785
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,tim[k]),[1,12*nside**2,self.NORIENT,odata])
1786
- res=self.backend.bk_concat([res,self.backend.bk_complex(rr,ii)],0)
1787
-
1788
- if len(ishape)==axis+1:
1789
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2,self.NORIENT])
2366
+ rr = self.backend.bk_reshape(
2367
+ self.backend.bk_sparse_dense_matmul(l_ww_real, tim[k]),
2368
+ [1, 12 * nside**2, self.NORIENT, odata],
2369
+ )
2370
+ ii = self.backend.bk_reshape(
2371
+ self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[k]),
2372
+ [1, 12 * nside**2, self.NORIENT, odata],
2373
+ )
2374
+ res = self.backend.bk_concat(
2375
+ [res, self.backend.bk_complex(rr, ii)], 0
2376
+ )
2377
+
2378
+ if len(ishape) == axis + 1:
2379
+ return self.backend.bk_reshape(
2380
+ res, ishape[0:axis] + [12 * nside**2, self.NORIENT]
2381
+ )
1790
2382
  else:
1791
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2]+ishape[axis+1:]+[self.NORIENT])
1792
-
1793
- if axis==0:
1794
- tim=self.backend.bk_reshape(self.backend.bk_cast(image),[12*nside**2,odata])
1795
- if tim.dtype==self.all_cbk_type:
1796
- rr1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_real(tim)),[12*nside**2,self.NORIENT,odata])
1797
- ii1=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_real(tim)),[12*nside**2,self.NORIENT,odata])
1798
- rr2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,self.backend.bk_imag(tim)),[12*nside**2,self.NORIENT,odata])
1799
- ii2=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,self.backend.bk_imag(tim)),[12*nside**2,self.NORIENT,odata])
1800
- res=self.backend.bk_complex(rr1-ii2,ii1+rr2)
2383
+ return self.backend.bk_reshape(
2384
+ res,
2385
+ ishape[0:axis]
2386
+ + [12 * nside**2]
2387
+ + ishape[axis + 1 :]
2388
+ + [self.NORIENT],
2389
+ )
2390
+
2391
+ if axis == 0:
2392
+ tim = self.backend.bk_reshape(
2393
+ self.backend.bk_cast(image), [12 * nside**2, odata]
2394
+ )
2395
+ if tim.dtype == self.all_cbk_type:
2396
+ rr1 = self.backend.bk_reshape(
2397
+ self.backend.bk_sparse_dense_matmul(
2398
+ l_ww_real, self.backend.bk_real(tim)
2399
+ ),
2400
+ [12 * nside**2, self.NORIENT, odata],
2401
+ )
2402
+ ii1 = self.backend.bk_reshape(
2403
+ self.backend.bk_sparse_dense_matmul(
2404
+ l_ww_imag, self.backend.bk_real(tim)
2405
+ ),
2406
+ [12 * nside**2, self.NORIENT, odata],
2407
+ )
2408
+ rr2 = self.backend.bk_reshape(
2409
+ self.backend.bk_sparse_dense_matmul(
2410
+ l_ww_real, self.backend.bk_imag(tim)
2411
+ ),
2412
+ [12 * nside**2, self.NORIENT, odata],
2413
+ )
2414
+ ii2 = self.backend.bk_reshape(
2415
+ self.backend.bk_sparse_dense_matmul(
2416
+ l_ww_imag, self.backend.bk_imag(tim)
2417
+ ),
2418
+ [12 * nside**2, self.NORIENT, odata],
2419
+ )
2420
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
1801
2421
  else:
1802
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_real,tim),[12*nside**2,self.NORIENT,odata])
1803
- ii=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_ww_imag,tim),[12*nside**2,self.NORIENT,odata])
1804
- res=self.backend.bk_complex(rr,ii)
1805
-
1806
- if len(ishape)==1:
1807
- return self.backend.bk_reshape(res,[12*nside**2,self.NORIENT])
2422
+ rr = self.backend.bk_reshape(
2423
+ self.backend.bk_sparse_dense_matmul(l_ww_real, tim),
2424
+ [12 * nside**2, self.NORIENT, odata],
2425
+ )
2426
+ ii = self.backend.bk_reshape(
2427
+ self.backend.bk_sparse_dense_matmul(l_ww_imag, tim),
2428
+ [12 * nside**2, self.NORIENT, odata],
2429
+ )
2430
+ res = self.backend.bk_complex(rr, ii)
2431
+
2432
+ if len(ishape) == 1:
2433
+ return self.backend.bk_reshape(res, [12 * nside**2, self.NORIENT])
1808
2434
  else:
1809
- return self.backend.bk_reshape(res,[12*nside**2]+ishape[axis+1:]+[self.NORIENT])
1810
- return(res)
1811
-
2435
+ return self.backend.bk_reshape(
2436
+ res, [12 * nside**2] + ishape[axis + 1 :] + [self.NORIENT]
2437
+ )
2438
+ return res
1812
2439
 
1813
2440
  # ---------------------------------------------−---------
1814
- def smooth(self,in_image,axis=0):
2441
+ def smooth(self, in_image, axis=0):
2442
+
2443
+ image = self.backend.bk_cast(in_image)
1815
2444
 
1816
- image=self.backend.bk_cast(in_image)
1817
-
1818
2445
  if self.use_2D:
1819
-
1820
- ishape=list(in_image.shape)
1821
- if len(ishape)<axis+2:
2446
+
2447
+ ishape = list(in_image.shape)
2448
+ if len(ishape) < axis + 2:
1822
2449
  if not self.silent:
1823
- print('Use of 2D scat with data that has less than 2D')
2450
+ print("Use of 2D scat with data that has less than 2D")
1824
2451
  return None
1825
-
1826
- npix=ishape[axis]
1827
- npiy=ishape[axis+1]
1828
- odata=1
1829
- if len(ishape)>axis+2:
1830
- for k in range(axis+2,len(ishape)):
1831
- odata=odata*ishape[k]
1832
-
1833
- ndata=1
2452
+
2453
+ npix = ishape[axis]
2454
+ npiy = ishape[axis + 1]
2455
+ odata = 1
2456
+ if len(ishape) > axis + 2:
2457
+ for k in range(axis + 2, len(ishape)):
2458
+ odata = odata * ishape[k]
2459
+
2460
+ ndata = 1
1834
2461
  for k in range(axis):
1835
- ndata=ndata*ishape[k]
2462
+ ndata = ndata * ishape[k]
1836
2463
 
1837
- tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,npiy,odata])
2464
+ tim = self.backend.bk_reshape(
2465
+ self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2466
+ )
1838
2467
 
1839
2468
  if self.backend.bk_is_complex(tim):
1840
- rr=self.backend.conv2d(self.backend.bk_real(tim),self.ww_SmoothT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1841
- ii=self.backend.conv2d(self.backend.bk_imag(tim),self.ww_SmoothT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1842
- res=self.backend.bk_complex(rr,ii)
2469
+ rr = self.backend.conv2d(
2470
+ self.backend.bk_real(tim),
2471
+ self.ww_SmoothT[odata],
2472
+ strides=[1, 1, 1, 1],
2473
+ padding=self.padding,
2474
+ )
2475
+ ii = self.backend.conv2d(
2476
+ self.backend.bk_imag(tim),
2477
+ self.ww_SmoothT[odata],
2478
+ strides=[1, 1, 1, 1],
2479
+ padding=self.padding,
2480
+ )
2481
+ res = self.backend.bk_complex(rr, ii)
1843
2482
  else:
1844
- res=self.backend.conv2d(tim,self.ww_SmoothT[odata],strides=[1, 1, 1, 1],padding=self.padding)
1845
-
1846
- if axis==0:
1847
- if len(ishape)==2:
1848
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2]])
2483
+ res = self.backend.conv2d(
2484
+ tim,
2485
+ self.ww_SmoothT[odata],
2486
+ strides=[1, 1, 1, 1],
2487
+ padding=self.padding,
2488
+ )
2489
+
2490
+ if axis == 0:
2491
+ if len(ishape) == 2:
2492
+ return self.backend.bk_reshape(res, [res.shape[1], res.shape[2]])
1849
2493
  else:
1850
- return self.backend.bk_reshape(res,[res.shape[1],res.shape[2]]+ishape[axis+2:])
2494
+ return self.backend.bk_reshape(
2495
+ res, [res.shape[1], res.shape[2]] + ishape[axis + 2 :]
2496
+ )
1851
2497
  else:
1852
- if len(ishape)==axis+2:
1853
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]])
2498
+ if len(ishape) == axis + 2:
2499
+ return self.backend.bk_reshape(
2500
+ res, ishape[0:axis] + [res.shape[1], res.shape[2]]
2501
+ )
1854
2502
  else:
1855
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1],res.shape[2]]+ishape[axis+2:])
1856
-
1857
- return self.backend.bk_reshape(res,[nout,nouty])
2503
+ return self.backend.bk_reshape(
2504
+ res,
2505
+ ishape[0:axis]
2506
+ + [res.shape[1], res.shape[2]]
2507
+ + ishape[axis + 2 :],
2508
+ )
2509
+
2510
+ return self.backend.bk_reshape(res, in_image.shape)
1858
2511
  elif self.use_1D:
1859
-
1860
- ishape=list(in_image.shape)
1861
- if len(ishape)<axis+1:
2512
+
2513
+ ishape = list(in_image.shape)
2514
+ if len(ishape) < axis + 1:
1862
2515
  if not self.silent:
1863
- print('Use of 1D scat with data that has less than 1D')
2516
+ print("Use of 1D scat with data that has less than 1D")
1864
2517
  return None
1865
-
1866
- npix=ishape[axis]
1867
- odata=1
1868
- if len(ishape)>axis+1:
1869
- for k in range(axis+1,len(ishape)):
1870
- odata=odata*ishape[k]
1871
-
1872
- ndata=1
2518
+
2519
+ npix = ishape[axis]
2520
+ odata = 1
2521
+ if len(ishape) > axis + 1:
2522
+ for k in range(axis + 1, len(ishape)):
2523
+ odata = odata * ishape[k]
2524
+
2525
+ ndata = 1
1873
2526
  for k in range(axis):
1874
- ndata=ndata*ishape[k]
2527
+ ndata = ndata * ishape[k]
1875
2528
 
1876
- tim=self.backend.bk_reshape(self.backend.bk_cast(in_image),[ndata,npix,odata])
2529
+ tim = self.backend.bk_reshape(
2530
+ self.backend.bk_cast(in_image), [ndata, npix, odata]
2531
+ )
1877
2532
 
1878
2533
  if self.backend.bk_is_complex(tim):
1879
- rr=self.backend.conv1d(self.backend.bk_real(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1880
- ii=self.backend.conv1d(self.backend.bk_imag(tim),self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1881
- res=self.backend.bk_complex(rr,ii)
2534
+ rr = self.backend.conv1d(
2535
+ self.backend.bk_real(tim),
2536
+ self.ww_SmoothT[odata],
2537
+ strides=[1, 1, 1],
2538
+ padding=self.padding,
2539
+ )
2540
+ ii = self.backend.conv1d(
2541
+ self.backend.bk_imag(tim),
2542
+ self.ww_SmoothT[odata],
2543
+ strides=[1, 1, 1],
2544
+ padding=self.padding,
2545
+ )
2546
+ res = self.backend.bk_complex(rr, ii)
1882
2547
  else:
1883
- res=self.backend.conv1d(tim,self.ww_SmoothT[odata],strides=[1, 1, 1],padding=self.padding)
1884
-
1885
- if axis==0:
1886
- if len(ishape)==1:
1887
- return self.backend.bk_reshape(res,[res.shape[1]])
2548
+ res = self.backend.conv1d(
2549
+ tim, self.ww_SmoothT[odata], strides=[1, 1, 1], padding=self.padding
2550
+ )
2551
+
2552
+ if axis == 0:
2553
+ if len(ishape) == 1:
2554
+ return self.backend.bk_reshape(res, [res.shape[1]])
1888
2555
  else:
1889
- return self.backend.bk_reshape(res,[res.shape[1]]+ishape[axis+1:])
2556
+ return self.backend.bk_reshape(
2557
+ res, [res.shape[1]] + ishape[axis + 1 :]
2558
+ )
1890
2559
  else:
1891
- if len(ishape)==axis+1:
1892
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]])
2560
+ if len(ishape) == axis + 1:
2561
+ return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
1893
2562
  else:
1894
- return self.backend.bk_reshape(res,ishape[0:axis]+[res.shape[1]]+ishape[axis+1:])
1895
-
1896
- return self.backend.bk_reshape(res,[nout,nouty])
1897
-
2563
+ return self.backend.bk_reshape(
2564
+ res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2565
+ )
2566
+
2567
+ return self.backend.bk_reshape(res, in_image.shape)
2568
+
1898
2569
  else:
1899
- nside=int(np.sqrt(image.shape[axis]//12))
2570
+ nside = int(np.sqrt(image.shape[axis] // 12))
1900
2571
 
1901
2572
  if self.Idx_Neighbours[nside] is None:
1902
-
2573
+
1903
2574
  if self.InitWave is None:
1904
- wr,wi,ws,widx=self.init_index(nside)
2575
+ wr, wi, ws, widx = self.init_index(nside)
1905
2576
  else:
1906
- wr,wi,ws,widx=self.InitWave(self,nside)
1907
-
1908
- self.Idx_Neighbours[nside]=1
1909
- self.ww_Real[nside]=wr
1910
- self.ww_Imag[nside]=wi
1911
- self.w_smooth[nside]=ws
1912
-
1913
- l_w_smooth=self.w_smooth[nside]
1914
- ishape=list(image.shape)
1915
-
1916
- odata=1
1917
- for k in range(axis+1,len(ishape)):
1918
- odata=odata*ishape[k]
1919
-
1920
- if axis==0:
1921
- tim=self.backend.bk_reshape(image,[12*nside**2,odata])
1922
- if tim.dtype==self.all_cbk_type:
1923
- rr=self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_real(tim))
1924
- ri=self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_imag(tim))
1925
- res=self.backend.bk_complex(rr,ri)
2577
+ wr, wi, ws, widx = self.InitWave(self, nside)
2578
+
2579
+ self.Idx_Neighbours[nside] = 1
2580
+ self.ww_Real[nside] = wr
2581
+ self.ww_Imag[nside] = wi
2582
+ self.w_smooth[nside] = ws
2583
+
2584
+ l_w_smooth = self.w_smooth[nside]
2585
+ ishape = list(image.shape)
2586
+
2587
+ odata = 1
2588
+ for k in range(axis + 1, len(ishape)):
2589
+ odata = odata * ishape[k]
2590
+
2591
+ if axis == 0:
2592
+ tim = self.backend.bk_reshape(image, [12 * nside**2, odata])
2593
+ if tim.dtype == self.all_cbk_type:
2594
+ rr = self.backend.bk_sparse_dense_matmul(
2595
+ l_w_smooth, self.backend.bk_real(tim)
2596
+ )
2597
+ ri = self.backend.bk_sparse_dense_matmul(
2598
+ l_w_smooth, self.backend.bk_imag(tim)
2599
+ )
2600
+ res = self.backend.bk_complex(rr, ri)
1926
2601
  else:
1927
- res=self.backend.bk_sparse_dense_matmul(l_w_smooth,tim)
1928
- if len(ishape)==1:
1929
- return self.backend.bk_reshape(res,[12*nside**2])
2602
+ res = self.backend.bk_sparse_dense_matmul(l_w_smooth, tim)
2603
+ if len(ishape) == 1:
2604
+ return self.backend.bk_reshape(res, [12 * nside**2])
1930
2605
  else:
1931
- return self.backend.bk_reshape(res,[12*nside**2]+ishape[axis+1:])
1932
-
1933
- if axis>0:
1934
- ndata=ishape[0]
1935
- for k in range(1,axis):
1936
- ndata=ndata*ishape[k]
1937
- tim=self.backend.bk_reshape(image,[ndata,12*nside**2,odata])
1938
- if tim.dtype==self.all_cbk_type:
1939
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_real(tim[0])),[1,12*nside**2,odata])
1940
- ri=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_imag(tim[0])),[1,12*nside**2,odata])
1941
- res=self.backend.bk_complex(rr,ri)
2606
+ return self.backend.bk_reshape(
2607
+ res, [12 * nside**2] + ishape[axis + 1 :]
2608
+ )
2609
+
2610
+ if axis > 0:
2611
+ ndata = ishape[0]
2612
+ for k in range(1, axis):
2613
+ ndata = ndata * ishape[k]
2614
+ tim = self.backend.bk_reshape(image, [ndata, 12 * nside**2, odata])
2615
+ if tim.dtype == self.all_cbk_type:
2616
+ rr = self.backend.bk_reshape(
2617
+ self.backend.bk_sparse_dense_matmul(
2618
+ l_w_smooth, self.backend.bk_real(tim[0])
2619
+ ),
2620
+ [1, 12 * nside**2, odata],
2621
+ )
2622
+ ri = self.backend.bk_reshape(
2623
+ self.backend.bk_sparse_dense_matmul(
2624
+ l_w_smooth, self.backend.bk_imag(tim[0])
2625
+ ),
2626
+ [1, 12 * nside**2, odata],
2627
+ )
2628
+ res = self.backend.bk_complex(rr, ri)
1942
2629
  else:
1943
- res=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,tim[0]),[1,12*nside**2,odata])
1944
-
1945
- for k in range(1,ndata):
1946
- if tim.dtype==self.all_cbk_type:
1947
- rr=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_real(tim[k])),[1,12*nside**2,odata])
1948
- ri=self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,self.backend.bk_imag(tim[k])),[1,12*nside**2,odata])
1949
- res=self.backend.bk_concat([res,self.backend.bk_complex(rr,ri)],0)
2630
+ res = self.backend.bk_reshape(
2631
+ self.backend.bk_sparse_dense_matmul(l_w_smooth, tim[0]),
2632
+ [1, 12 * nside**2, odata],
2633
+ )
2634
+
2635
+ for k in range(1, ndata):
2636
+ if tim.dtype == self.all_cbk_type:
2637
+ rr = self.backend.bk_reshape(
2638
+ self.backend.bk_sparse_dense_matmul(
2639
+ l_w_smooth, self.backend.bk_real(tim[k])
2640
+ ),
2641
+ [1, 12 * nside**2, odata],
2642
+ )
2643
+ ri = self.backend.bk_reshape(
2644
+ self.backend.bk_sparse_dense_matmul(
2645
+ l_w_smooth, self.backend.bk_imag(tim[k])
2646
+ ),
2647
+ [1, 12 * nside**2, odata],
2648
+ )
2649
+ res = self.backend.bk_concat(
2650
+ [res, self.backend.bk_complex(rr, ri)], 0
2651
+ )
1950
2652
  else:
1951
- res=self.backend.bk_concat([res,self.backend.bk_reshape(self.backend.bk_sparse_dense_matmul(l_w_smooth,tim[k]),[1,12*nside**2,odata])],0)
1952
-
1953
- if len(ishape)==axis+1:
1954
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2])
2653
+ res = self.backend.bk_concat(
2654
+ [
2655
+ res,
2656
+ self.backend.bk_reshape(
2657
+ self.backend.bk_sparse_dense_matmul(
2658
+ l_w_smooth, tim[k]
2659
+ ),
2660
+ [1, 12 * nside**2, odata],
2661
+ ),
2662
+ ],
2663
+ 0,
2664
+ )
2665
+
2666
+ if len(ishape) == axis + 1:
2667
+ return self.backend.bk_reshape(
2668
+ res, ishape[0:axis] + [12 * nside**2]
2669
+ )
1955
2670
  else:
1956
- return self.backend.bk_reshape(res,ishape[0:axis]+[12*nside**2]+ishape[axis+1:])
1957
-
1958
-
1959
- return(res)
1960
-
2671
+ return self.backend.bk_reshape(
2672
+ res, ishape[0:axis] + [12 * nside**2] + ishape[axis + 1 :]
2673
+ )
2674
+
2675
+ return res
2676
+
1961
2677
  # ---------------------------------------------−---------
1962
2678
  def get_kernel_size(self):
1963
- return(self.KERNELSZ)
1964
-
2679
+ return self.KERNELSZ
2680
+
1965
2681
  # ---------------------------------------------−---------
1966
2682
  def get_nb_orient(self):
1967
- return(self.NORIENT)
1968
-
2683
+ return self.NORIENT
2684
+
1969
2685
  # ---------------------------------------------−---------
1970
- def get_ww(self,nside=1):
1971
- return(self.ww_Real[nside],self.ww_Imag[nside])
1972
-
2686
+ def get_ww(self, nside=1):
2687
+ return (self.ww_Real[nside], self.ww_Imag[nside])
2688
+
1973
2689
  # ---------------------------------------------−---------
1974
2690
  def plot_ww(self):
1975
- c,s=self.get_ww()
2691
+ c, s = self.get_ww()
1976
2692
  import matplotlib.pyplot as plt
1977
- plt.figure(figsize=(16,6))
1978
- npt=int(np.sqrt(c.shape[0]))
2693
+
2694
+ plt.figure(figsize=(16, 6))
2695
+ npt = int(np.sqrt(c.shape[0]))
1979
2696
  for i in range(c.shape[1]):
1980
- plt.subplot(2,c.shape[1],1+i)
1981
- plt.imshow(c[:,i].reshape(npt,npt),cmap='jet',vmin=-c.max(),vmax=c.max())
1982
- plt.subplot(2,c.shape[1],1+i+c.shape[1])
1983
- plt.imshow(s[:,i].reshape(npt,npt),cmap='jet',vmin=-c.max(),vmax=c.max())
2697
+ plt.subplot(2, c.shape[1], 1 + i)
2698
+ plt.imshow(
2699
+ c[:, i].reshape(npt, npt), cmap="jet", vmin=-c.max(), vmax=c.max()
2700
+ )
2701
+ plt.subplot(2, c.shape[1], 1 + i + c.shape[1])
2702
+ plt.imshow(
2703
+ s[:, i].reshape(npt, npt), cmap="jet", vmin=-c.max(), vmax=c.max()
2704
+ )
1984
2705
  sys.stdout.flush()
1985
2706
  plt.show()
1986
-
1987
-
1988
-
1989
-
1990
-