foscat 3.1.6__py3-none-any.whl → 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat.py CHANGED
@@ -1,433 +1,602 @@
1
- import foscat.FoCUS as FOC
2
- import numpy as np
3
1
  import pickle
4
- import foscat.backend as bk
5
- import healpy as hp
6
2
  import sys
7
-
3
+
4
+ import healpy as hp
5
+ import numpy as np
6
+
7
+ import foscat.backend as bk
8
+ import foscat.FoCUS as FOC
9
+
8
10
  # Vérifier si TensorFlow est importé et défini
9
- tf_defined = 'tensorflow' in sys.modules
11
+ tf_defined = "tensorflow" in sys.modules
10
12
 
11
13
  if tf_defined:
12
14
  import tensorflow as tf
13
- tf_function = tf.function # Facultatif : si vous voulez utiliser TensorFlow dans ce script
15
+
16
+ tf_function = (
17
+ tf.function
18
+ ) # Facultatif : si vous voulez utiliser TensorFlow dans ce script
14
19
  else:
20
+
15
21
  def tf_function(func):
16
22
  return func
17
-
23
+
24
+
18
25
  def read(filename):
19
- thescat=scat(1,1,1,1,1,[0],[0])
26
+ thescat = scat(1, 1, 1, 1, 1, [0], [0])
20
27
  return thescat.read(filename)
21
-
28
+
29
+
22
30
  class scat:
23
- def __init__(self,p00,s0,s1,s2,s2l,j1,j2,cross=False,backend=None):
24
- self.bk_type='SCAT'
25
- self.P00=p00
26
- self.S0=s0
27
- self.S1=s1
28
- self.S2=s2
29
- self.S2L=s2l
30
- self.j1=j1
31
- self.j2=j2
32
- self.cross=cross
33
- self.backend=backend
34
-
35
- def set_bk_type(self,bk_type):
36
- self.bk_type=bk_type
37
-
31
+ def __init__(self, p00, s0, s1, s2, s2l, j1, j2, cross=False, backend=None):
32
+ self.bk_type = "SCAT"
33
+ self.P00 = p00
34
+ self.S0 = s0
35
+ self.S1 = s1
36
+ self.S2 = s2
37
+ self.S2L = s2l
38
+ self.j1 = j1
39
+ self.j2 = j2
40
+ self.cross = cross
41
+ self.backend = backend
42
+
43
+ def set_bk_type(self, bk_type):
44
+ self.bk_type = bk_type
45
+
38
46
  def get_j_idx(self):
39
- return self.j1,self.j2
40
-
47
+ return self.j1, self.j2
48
+
41
49
  def get_S0(self):
42
- return(self.S0)
50
+ return self.S0
43
51
 
44
52
  def get_S1(self):
45
- return(self.S1)
46
-
53
+ return self.S1
54
+
47
55
  def get_S2(self):
48
- return(self.S2)
56
+ return self.S2
49
57
 
50
58
  def get_S2L(self):
51
- return(self.S2L)
59
+ return self.S2L
52
60
 
53
61
  def get_P00(self):
54
- return(self.P00)
62
+ return self.P00
55
63
 
56
64
  def reset_P00(self):
57
- self.P00=0*self.P00
65
+ self.P00 = 0 * self.P00
58
66
 
59
67
  def constant(self):
60
- return scat(self.backend.constant(self.P00 ), \
61
- self.backend.constant(self.S0 ), \
62
- self.backend.constant(self.S1 ), \
63
- self.backend.constant(self.S2 ), \
64
- self.backend.constant(self.S2L ), \
65
- self.j1 , \
66
- self.j2 ,backend=self.backend)
67
-
68
-
69
- def domult(self,x,y):
68
+ return scat(
69
+ self.backend.constant(self.P00),
70
+ self.backend.constant(self.S0),
71
+ self.backend.constant(self.S1),
72
+ self.backend.constant(self.S2),
73
+ self.backend.constant(self.S2L),
74
+ self.j1,
75
+ self.j2,
76
+ backend=self.backend,
77
+ )
78
+
79
+ def domult(self, x, y):
70
80
  try:
71
- return x*y
81
+ return x * y
72
82
  except:
73
- if x.dtype==y.dtype:
74
- return x*y
83
+ if x.dtype == y.dtype:
84
+ return x * y
75
85
  if self.backend.bk_is_complex(x):
76
86
 
77
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
87
+ return self.backend.bk_complex(
88
+ self.backend.bk_real(x) * y, self.backend.bk_imag(x) * y
89
+ )
78
90
  else:
79
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
91
+ return self.backend.bk_complex(
92
+ self.backend.bk_real(y) * x, self.backend.bk_imag(y) * x
93
+ )
80
94
 
81
- def dodiv(self,x,y):
95
+ def dodiv(self, x, y):
82
96
  try:
83
- return x/y
97
+ return x / y
84
98
  except:
85
- if x.dtype==y.dtype:
86
- return x/y
99
+ if x.dtype == y.dtype:
100
+ return x / y
87
101
  if self.backend.bk_is_complex(x):
88
-
89
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
102
+
103
+ return self.backend.bk_complex(
104
+ self.backend.bk_real(x) / y, self.backend.bk_imag(x) / y
105
+ )
90
106
  else:
91
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
92
-
93
- def domin(self,x,y):
107
+ return self.backend.bk_complex(
108
+ x / self.backend.bk_real(y), x / self.backend.bk_imag(y)
109
+ )
110
+
111
+ def domin(self, x, y):
94
112
  try:
95
- return x-y
113
+ return x - y
96
114
  except:
97
- if x.dtype==y.dtype:
98
- return x-y
115
+ if x.dtype == y.dtype:
116
+ return x - y
99
117
 
100
118
  if self.backend.bk_is_complex(x):
101
119
 
102
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
120
+ return self.backend.bk_complex(
121
+ self.backend.bk_real(x) - y, self.backend.bk_imag(x) - y
122
+ )
103
123
  else:
104
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
105
-
106
- def doadd(self,x,y):
124
+ return self.backend.bk_complex(
125
+ x - self.backend.bk_real(y), x - self.backend.bk_imag(y)
126
+ )
127
+
128
+ def doadd(self, x, y):
107
129
  try:
108
- return x+y
130
+ return x + y
109
131
  except:
110
- if x.dtype==y.dtype:
111
- return x+y
132
+ if x.dtype == y.dtype:
133
+ return x + y
112
134
  if self.backend.bk_is_complex(x):
113
135
 
114
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
136
+ return self.backend.bk_complex(
137
+ self.backend.bk_real(x) + y, self.backend.bk_imag(x) + y
138
+ )
115
139
  else:
116
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
117
-
118
- def relu(self):
119
-
120
- return scat(self.backend.bk_relu(self.P00), \
121
- self.backend.bk_relu(self.S0), \
122
- self.backend.bk_relu(self.S1), \
123
- self.backend.bk_relu(self.S2), \
124
- self.backend.bk_relu(self.S2L), \
125
- self.j1,self.j2,backend=self.backend)
126
-
127
- def __add__(self,other):
128
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
129
- isinstance(other, bool) or isinstance(other, scat)
130
-
140
+ return self.backend.bk_complex(
141
+ x + self.backend.bk_real(y), x + self.backend.bk_imag(y)
142
+ )
143
+
144
+ def __add__(self, other):
145
+ assert (
146
+ isinstance(other, float)
147
+ or isinstance(other, np.float32)
148
+ or isinstance(other, int)
149
+ or isinstance(other, bool)
150
+ or isinstance(other, scat)
151
+ )
152
+
131
153
  if isinstance(other, scat):
132
- return scat(self.doadd(self.P00,other.P00), \
133
- self.doadd(self.S0, other.S0), \
134
- self.doadd(self.S1, other.S1), \
135
- self.doadd(self.S2, other.S2), \
136
- self.doadd(self.S2L, other.S2L), \
137
- self.j1,self.j2,backend=self.backend)
154
+ return scat(
155
+ self.doadd(self.P00, other.P00),
156
+ self.doadd(self.S0, other.S0),
157
+ self.doadd(self.S1, other.S1),
158
+ self.doadd(self.S2, other.S2),
159
+ self.doadd(self.S2L, other.S2L),
160
+ self.j1,
161
+ self.j2,
162
+ backend=self.backend,
163
+ )
138
164
  else:
139
- return scat((self.P00+ other), \
140
- (self.S0+ other), \
141
- (self.S1+ other), \
142
- (self.S2+ other), \
143
- (self.S2L+ other), \
144
- self.j1,self.j2,backend=self.backend)
145
-
146
- def toreal(self,value):
165
+ return scat(
166
+ (self.P00 + other),
167
+ (self.S0 + other),
168
+ (self.S1 + other),
169
+ (self.S2 + other),
170
+ (self.S2L + other),
171
+ self.j1,
172
+ self.j2,
173
+ backend=self.backend,
174
+ )
175
+
176
+ def toreal(self, value):
147
177
  if value is None:
148
178
  return None
149
-
179
+
150
180
  return self.backend.bk_real(value)
151
181
 
152
- def addcomplex(self,value,amp):
182
+ def addcomplex(self, value, amp):
153
183
  if value is None:
154
184
  return None
155
-
156
- return self.backend.bk_complex(value,amp*value)
157
-
158
- def add_complex(self,amp):
159
- return scat(self.addcomplex(self.P00,amp), \
160
- self.addcomplex(self.S0,amp), \
161
- self.addcomplex(self.S1,amp), \
162
- self.addcomplex(self.S2,amp), \
163
- self.addcomplex(self.S2L,amp), \
164
- self.j1,self.j2,backend=self.backend)
165
-
185
+
186
+ return self.backend.bk_complex(value, amp * value)
187
+
188
+ def add_complex(self, amp):
189
+ return scat(
190
+ self.addcomplex(self.P00, amp),
191
+ self.addcomplex(self.S0, amp),
192
+ self.addcomplex(self.S1, amp),
193
+ self.addcomplex(self.S2, amp),
194
+ self.addcomplex(self.S2L, amp),
195
+ self.j1,
196
+ self.j2,
197
+ backend=self.backend,
198
+ )
199
+
166
200
  def real(self):
167
- return scat(self.toreal(self.P00), \
168
- self.toreal(self.S0), \
169
- self.toreal(self.S1), \
170
- self.toreal(self.S2), \
171
- self.toreal(self.S2L), \
172
- self.j1,self.j2,backend=self.backend)
173
-
174
- def __radd__(self,other):
201
+ return scat(
202
+ self.toreal(self.P00),
203
+ self.toreal(self.S0),
204
+ self.toreal(self.S1),
205
+ self.toreal(self.S2),
206
+ self.toreal(self.S2L),
207
+ self.j1,
208
+ self.j2,
209
+ backend=self.backend,
210
+ )
211
+
212
+ def __radd__(self, other):
175
213
  return self.__add__(other)
176
214
 
177
- def __truediv__(self,other):
178
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
179
- isinstance(other, bool) or isinstance(other, scat)
180
-
215
+ def __truediv__(self, other):
216
+ assert (
217
+ isinstance(other, float)
218
+ or isinstance(other, np.float32)
219
+ or isinstance(other, int)
220
+ or isinstance(other, bool)
221
+ or isinstance(other, scat)
222
+ )
223
+
181
224
  if isinstance(other, scat):
182
- return scat(self.dodiv(self.P00, other.P00), \
183
- self.dodiv(self.S0, other.S0), \
184
- self.dodiv(self.S1, other.S1), \
185
- self.dodiv(self.S2, other.S2), \
186
- self.dodiv(self.S2L, other.S2L), \
187
- self.j1,self.j2,backend=self.backend)
225
+ return scat(
226
+ self.dodiv(self.P00, other.P00),
227
+ self.dodiv(self.S0, other.S0),
228
+ self.dodiv(self.S1, other.S1),
229
+ self.dodiv(self.S2, other.S2),
230
+ self.dodiv(self.S2L, other.S2L),
231
+ self.j1,
232
+ self.j2,
233
+ backend=self.backend,
234
+ )
188
235
  else:
189
- return scat((self.P00/ other), \
190
- (self.S0/ other), \
191
- (self.S1/ other), \
192
- (self.S2/ other), \
193
- (self.S2L/ other), \
194
- self.j1,self.j2,backend=self.backend)
195
-
196
-
197
- def __rtruediv__(self,other):
198
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
199
- isinstance(other, bool) or isinstance(other, scat)
200
-
236
+ return scat(
237
+ (self.P00 / other),
238
+ (self.S0 / other),
239
+ (self.S1 / other),
240
+ (self.S2 / other),
241
+ (self.S2L / other),
242
+ self.j1,
243
+ self.j2,
244
+ backend=self.backend,
245
+ )
246
+
247
+ def __rtruediv__(self, other):
248
+ assert (
249
+ isinstance(other, float)
250
+ or isinstance(other, np.float32)
251
+ or isinstance(other, int)
252
+ or isinstance(other, bool)
253
+ or isinstance(other, scat)
254
+ )
255
+
201
256
  if isinstance(other, scat):
202
- return scat(self.dodiv(other.P00, self.P00), \
203
- self.dodiv(other.S0 , self.S0), \
204
- self.dodiv(other.S1 , self.S1), \
205
- self.dodiv(other.S2 , self.S2), \
206
- self.dodiv(other.S2L , self.S2L), \
207
- self.j1,self.j2,backend=self.backend)
257
+ return scat(
258
+ self.dodiv(other.P00, self.P00),
259
+ self.dodiv(other.S0, self.S0),
260
+ self.dodiv(other.S1, self.S1),
261
+ self.dodiv(other.S2, self.S2),
262
+ self.dodiv(other.S2L, self.S2L),
263
+ self.j1,
264
+ self.j2,
265
+ backend=self.backend,
266
+ )
208
267
  else:
209
- return scat((other/ self.P00), \
210
- (other / self.S0), \
211
- (other / self.S1), \
212
- (other / self.S2), \
213
- (other / self.S2L), \
214
- self.j1,self.j2,backend=self.backend)
215
-
216
- def __sub__(self,other):
217
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
218
- isinstance(other, bool) or isinstance(other, scat)
219
-
268
+ return scat(
269
+ (other / self.P00),
270
+ (other / self.S0),
271
+ (other / self.S1),
272
+ (other / self.S2),
273
+ (other / self.S2L),
274
+ self.j1,
275
+ self.j2,
276
+ backend=self.backend,
277
+ )
278
+
279
+ def __sub__(self, other):
280
+ assert (
281
+ isinstance(other, float)
282
+ or isinstance(other, np.float32)
283
+ or isinstance(other, int)
284
+ or isinstance(other, bool)
285
+ or isinstance(other, scat)
286
+ )
287
+
220
288
  if isinstance(other, scat):
221
- return scat(self.domin(self.P00, other.P00), \
222
- self.domin(self.S0, other.S0), \
223
- self.domin(self.S1, other.S1), \
224
- self.domin(self.S2, other.S2), \
225
- self.domin(self.S2L, other.S2L), \
226
- self.j1,self.j2,backend=self.backend)
289
+ return scat(
290
+ self.domin(self.P00, other.P00),
291
+ self.domin(self.S0, other.S0),
292
+ self.domin(self.S1, other.S1),
293
+ self.domin(self.S2, other.S2),
294
+ self.domin(self.S2L, other.S2L),
295
+ self.j1,
296
+ self.j2,
297
+ backend=self.backend,
298
+ )
227
299
  else:
228
- return scat((self.P00- other), \
229
- (self.S0- other), \
230
- (self.S1- other), \
231
- (self.S2- other), \
232
- (self.S2L- other), \
233
- self.j1,self.j2,backend=self.backend)
234
-
235
- def __rsub__(self,other):
236
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
237
- isinstance(other, bool) or isinstance(other, scat)
238
-
300
+ return scat(
301
+ (self.P00 - other),
302
+ (self.S0 - other),
303
+ (self.S1 - other),
304
+ (self.S2 - other),
305
+ (self.S2L - other),
306
+ self.j1,
307
+ self.j2,
308
+ backend=self.backend,
309
+ )
310
+
311
+ def __rsub__(self, other):
312
+ assert (
313
+ isinstance(other, float)
314
+ or isinstance(other, np.float32)
315
+ or isinstance(other, int)
316
+ or isinstance(other, bool)
317
+ or isinstance(other, scat)
318
+ )
319
+
239
320
  if isinstance(other, scat):
240
- return scat(self.domin(other.P00,self.P00), \
241
- self.domin(other.S0, self.S0), \
242
- self.domin(other.S1, self.S1), \
243
- self.domin(other.S2, self.S2), \
244
- self.domin(other.S2L, self.S2L), \
245
- self.j1,self.j2,backend=self.backend)
321
+ return scat(
322
+ self.domin(other.P00, self.P00),
323
+ self.domin(other.S0, self.S0),
324
+ self.domin(other.S1, self.S1),
325
+ self.domin(other.S2, self.S2),
326
+ self.domin(other.S2L, self.S2L),
327
+ self.j1,
328
+ self.j2,
329
+ backend=self.backend,
330
+ )
246
331
  else:
247
- return scat((other-self.P00), \
248
- (other-self.S0), \
249
- (other-self.S1), \
250
- (other-self.S2), \
251
- (other-self.S2L), \
252
- self.j1,self.j2,backend=self.backend)
253
-
254
- def __mul__(self,other):
255
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
256
- isinstance(other, bool) or isinstance(other, scat)
257
-
332
+ return scat(
333
+ (other - self.P00),
334
+ (other - self.S0),
335
+ (other - self.S1),
336
+ (other - self.S2),
337
+ (other - self.S2L),
338
+ self.j1,
339
+ self.j2,
340
+ backend=self.backend,
341
+ )
342
+
343
+ def __mul__(self, other):
344
+ assert (
345
+ isinstance(other, float)
346
+ or isinstance(other, np.float32)
347
+ or isinstance(other, int)
348
+ or isinstance(other, bool)
349
+ or isinstance(other, scat)
350
+ )
351
+
258
352
  if isinstance(other, scat):
259
- return scat(self.domult(self.P00, other.P00), \
260
- self.domult(self.S0, other.S0), \
261
- self.domult(self.S1, other.S1), \
262
- self.domult(self.S2, other.S2), \
263
- self.domult(self.S2L, other.S2L), \
264
- self.j1,self.j2,backend=self.backend)
353
+ return scat(
354
+ self.domult(self.P00, other.P00),
355
+ self.domult(self.S0, other.S0),
356
+ self.domult(self.S1, other.S1),
357
+ self.domult(self.S2, other.S2),
358
+ self.domult(self.S2L, other.S2L),
359
+ self.j1,
360
+ self.j2,
361
+ backend=self.backend,
362
+ )
265
363
  else:
266
- return scat((self.P00* other), \
267
- (self.S0* other), \
268
- (self.S1* other), \
269
- (self.S2* other), \
270
- (self.S2L* other), \
271
- self.j1,self.j2,backend=self.backend)
364
+ return scat(
365
+ (self.P00 * other),
366
+ (self.S0 * other),
367
+ (self.S1 * other),
368
+ (self.S2 * other),
369
+ (self.S2L * other),
370
+ self.j1,
371
+ self.j2,
372
+ backend=self.backend,
373
+ )
374
+
272
375
  def relu(self):
273
- return scat(self.backend.bk_relu(self.P00),
274
- self.backend.bk_relu(self.S0),
275
- self.backend.bk_relu(self.S1),
276
- self.backend.bk_relu(self.S2),
277
- self.backend.bk_relu(self.S2L), \
278
- self.j1,self.j2,backend=self.backend)
279
-
280
-
281
- def __rmul__(self,other):
282
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
283
- isinstance(other, bool) or isinstance(other, scat)
284
-
376
+ return scat(
377
+ self.backend.bk_relu(self.P00),
378
+ self.backend.bk_relu(self.S0),
379
+ self.backend.bk_relu(self.S1),
380
+ self.backend.bk_relu(self.S2),
381
+ self.backend.bk_relu(self.S2L),
382
+ self.j1,
383
+ self.j2,
384
+ backend=self.backend,
385
+ )
386
+
387
+ def __rmul__(self, other):
388
+ assert (
389
+ isinstance(other, float)
390
+ or isinstance(other, np.float32)
391
+ or isinstance(other, int)
392
+ or isinstance(other, bool)
393
+ or isinstance(other, scat)
394
+ )
395
+
285
396
  if isinstance(other, scat):
286
- return scat(self.domult(self.P00, other.P00), \
287
- self.domult(self.S0, other.S0), \
288
- self.domult(self.S1, other.S1), \
289
- self.domult(self.S2, other.S2), \
290
- self.domult(self.S2L, other.S2L), \
291
- self.j1,self.j2,backend=self.backend)
397
+ return scat(
398
+ self.domult(self.P00, other.P00),
399
+ self.domult(self.S0, other.S0),
400
+ self.domult(self.S1, other.S1),
401
+ self.domult(self.S2, other.S2),
402
+ self.domult(self.S2L, other.S2L),
403
+ self.j1,
404
+ self.j2,
405
+ backend=self.backend,
406
+ )
292
407
  else:
293
- return scat((self.P00* other), \
294
- (self.S0* other), \
295
- (self.S1* other), \
296
- (self.S2* other), \
297
- (self.S2L* other), \
298
- self.j1,self.j2,backend=self.backend)
299
-
300
- def l1_abs(self,x):
301
- y=self.get_np(x)
408
+ return scat(
409
+ (self.P00 * other),
410
+ (self.S0 * other),
411
+ (self.S1 * other),
412
+ (self.S2 * other),
413
+ (self.S2L * other),
414
+ self.j1,
415
+ self.j2,
416
+ backend=self.backend,
417
+ )
418
+
419
+ def l1_abs(self, x):
420
+ y = self.get_np(x)
302
421
  if self.backend.bk_is_complex(y):
303
- tmp=y.real*y.real+y.imag*y.imag
304
- tmp=np.sign(tmp)*np.sqrt(np.fabs(tmp))
305
- y=tmp
306
-
307
- return(y)
308
-
309
- def plot(self,name=None,hold=True,color='blue',lw=1,legend=True):
422
+ tmp = y.real * y.real + y.imag * y.imag
423
+ tmp = np.sign(tmp) * np.sqrt(np.fabs(tmp))
424
+ y = tmp
425
+
426
+ return y
427
+
428
+ def plot(self, name=None, hold=True, color="blue", lw=1, legend=True):
310
429
 
311
430
  import matplotlib.pyplot as plt
312
431
 
313
- j1,j2=self.get_j_idx()
314
-
432
+ j1, j2 = self.get_j_idx()
433
+
315
434
  if name is None:
316
- name=''
435
+ name = ""
317
436
 
318
437
  if hold:
319
- plt.figure(figsize=(16,8))
320
-
321
- test=None
438
+ plt.figure(figsize=(16, 8))
439
+
440
+ test = None
322
441
  plt.subplot(2, 2, 1)
323
- tmp=abs(self.get_np(self.S1))
324
- if len(tmp.shape)==4:
442
+ tmp = abs(self.get_np(self.S1))
443
+ if len(tmp.shape) == 4:
325
444
  for k in range(tmp.shape[3]):
326
445
  for i1 in range(tmp.shape[0]):
327
446
  for i2 in range(tmp.shape[1]):
328
447
  if test is None:
329
- test=1
330
- plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
448
+ test = 1
449
+ plt.plot(
450
+ tmp[i1, i2, :, k],
451
+ color=color,
452
+ label=r"%s $S_{1}$" % (name),
453
+ lw=lw,
454
+ )
331
455
  else:
332
- plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
456
+ plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
333
457
  else:
334
458
  for k in range(tmp.shape[2]):
335
459
  for i1 in range(tmp.shape[0]):
336
460
  if test is None:
337
- test=1
338
- plt.plot(tmp[i1,:,k],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
461
+ test = 1
462
+ plt.plot(
463
+ tmp[i1, :, k],
464
+ color=color,
465
+ label=r"%s $S_{1}$" % (name),
466
+ lw=lw,
467
+ )
339
468
  else:
340
- plt.plot(tmp[i1,:,k],color=color, lw=lw)
341
- plt.yscale('log')
342
- plt.ylabel('S1')
343
- plt.xlabel(r'$j_{1}$')
469
+ plt.plot(tmp[i1, :, k], color=color, lw=lw)
470
+ plt.yscale("log")
471
+ plt.ylabel("S1")
472
+ plt.xlabel(r"$j_{1}$")
344
473
  plt.legend()
345
474
 
346
- test=None
475
+ test = None
347
476
  plt.subplot(2, 2, 2)
348
- tmp=abs(self.get_np(self.P00))
349
- if len(tmp.shape)==4:
477
+ tmp = abs(self.get_np(self.P00))
478
+ if len(tmp.shape) == 4:
350
479
  for k in range(tmp.shape[3]):
351
480
  for i1 in range(tmp.shape[0]):
352
481
  for i2 in range(tmp.shape[1]):
353
482
  if test is None:
354
- test=1
355
- plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
483
+ test = 1
484
+ plt.plot(
485
+ tmp[i1, i2, :, k],
486
+ color=color,
487
+ label=r"%s $P_{00}$" % (name),
488
+ lw=lw,
489
+ )
356
490
  else:
357
- plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
491
+ plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
358
492
  else:
359
493
  for k in range(tmp.shape[2]):
360
494
  for i1 in range(tmp.shape[0]):
361
495
  if test is None:
362
- test=1
363
- plt.plot(tmp[i1,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
496
+ test = 1
497
+ plt.plot(
498
+ tmp[i1, :, k],
499
+ color=color,
500
+ label=r"%s $P_{00}$" % (name),
501
+ lw=lw,
502
+ )
364
503
  else:
365
- plt.plot(tmp[i1,:,k],color=color, lw=lw)
366
- plt.yscale('log')
367
- plt.ylabel('P00')
368
- plt.xlabel(r'$j_{1}$')
504
+ plt.plot(tmp[i1, :, k], color=color, lw=lw)
505
+ plt.yscale("log")
506
+ plt.ylabel("P00")
507
+ plt.xlabel(r"$j_{1}$")
369
508
  plt.legend()
370
-
371
- ax1=plt.subplot(2, 2, 3)
509
+
510
+ ax1 = plt.subplot(2, 2, 3)
372
511
  ax2 = ax1.twiny()
373
- n=0
374
- tmp=abs(self.get_np(self.S2))
375
- lname=r'%s $S_{2}$' % (name)
376
- ax1.set_ylabel(r'$S_{2}$ [L1 norm]')
377
- test=None
378
- tabx=[]
379
- tabnx=[]
380
- tab2x=[]
381
- tab2nx=[]
382
- if len(tmp.shape)==5:
512
+ n = 0
513
+ tmp = abs(self.get_np(self.S2))
514
+ lname = r"%s $S_{2}$" % (name)
515
+ ax1.set_ylabel(r"$S_{2}$ [L1 norm]")
516
+ test = None
517
+ tabx = []
518
+ tabnx = []
519
+ tab2x = []
520
+ tab2nx = []
521
+ if len(tmp.shape) == 5:
383
522
  for i0 in range(tmp.shape[0]):
384
523
  for i1 in range(tmp.shape[1]):
385
- for i2 in range(j1.max()+1):
524
+ for i2 in range(j1.max() + 1):
386
525
  for i3 in range(tmp.shape[3]):
387
526
  for i4 in range(tmp.shape[4]):
388
- if j2[j1==i2].shape[0]==1:
389
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4],'.', \
390
- color=color, lw=lw)
527
+ if j2[j1 == i2].shape[0] == 1:
528
+ ax1.plot(
529
+ j2[j1 == i2] + n,
530
+ tmp[i0, i1, j1 == i2, i3, i4],
531
+ ".",
532
+ color=color,
533
+ lw=lw,
534
+ )
391
535
  else:
392
536
  if legend and test is None:
393
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
394
- color=color, label=lname, lw=lw)
395
- test=1
396
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
397
- color=color, lw=lw)
398
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
399
- tabx=tabx+[k+n for k in j2[j1==i2]]
400
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
401
- tab2nx=tab2nx+['%d'%(i2)]
402
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
403
- n=n+j2[j1==i2].shape[0]-1
537
+ ax1.plot(
538
+ j2[j1 == i2] + n,
539
+ tmp[i0, i1, j1 == i2, i3, i4],
540
+ color=color,
541
+ label=lname,
542
+ lw=lw,
543
+ )
544
+ test = 1
545
+ ax1.plot(
546
+ j2[j1 == i2] + n,
547
+ tmp[i0, i1, j1 == i2, i3, i4],
548
+ color=color,
549
+ lw=lw,
550
+ )
551
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
552
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
553
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
554
+ tab2nx = tab2nx + ["%d" % (i2)]
555
+ ax1.axvline(
556
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
557
+ )
558
+ n = n + j2[j1 == i2].shape[0] - 1
404
559
  else:
405
560
  for i0 in range(tmp.shape[0]):
406
- for i2 in range(j1.max()+1):
561
+ for i2 in range(j1.max() + 1):
407
562
  for i3 in range(tmp.shape[2]):
408
563
  for i4 in range(tmp.shape[3]):
409
- if j2[j1==i2].shape[0]==1:
410
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4],'.', \
411
- color=color, lw=lw)
564
+ if j2[j1 == i2].shape[0] == 1:
565
+ ax1.plot(
566
+ j2[j1 == i2] + n,
567
+ tmp[i0, j1 == i2, i3, i4],
568
+ ".",
569
+ color=color,
570
+ lw=lw,
571
+ )
412
572
  else:
413
573
  if legend and test is None:
414
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
415
- color=color, label=lname, lw=lw)
416
- test=1
417
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
418
- color=color, lw=lw)
419
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
420
- tabx=tabx+[k+n for k in j2[j1==i2]]
421
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
422
- tab2nx=tab2nx+['%d'%(i2)]
423
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
424
- n=n+j2[j1==i2].shape[0]-1
425
- plt.yscale('log')
426
- ax1.set_xlim(0,n+2)
574
+ ax1.plot(
575
+ j2[j1 == i2] + n,
576
+ tmp[i0, j1 == i2, i3, i4],
577
+ color=color,
578
+ label=lname,
579
+ lw=lw,
580
+ )
581
+ test = 1
582
+ ax1.plot(
583
+ j2[j1 == i2] + n,
584
+ tmp[i0, j1 == i2, i3, i4],
585
+ color=color,
586
+ lw=lw,
587
+ )
588
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
589
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
590
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
591
+ tab2nx = tab2nx + ["%d" % (i2)]
592
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
593
+ n = n + j2[j1 == i2].shape[0] - 1
594
+ plt.yscale("log")
595
+ ax1.set_xlim(0, n + 2)
427
596
  ax1.set_xticks(tabx)
428
- ax1.set_xticklabels(tabnx,fontsize=6)
429
- ax1.set_xlabel(r"$j_{2}$ ",fontsize=6)
430
-
597
+ ax1.set_xticklabels(tabnx, fontsize=6)
598
+ ax1.set_xlabel(r"$j_{2}$ ", fontsize=6)
599
+
431
600
  # Move twinned axis ticks and label from top to bottom
432
601
  ax2.xaxis.set_ticks_position("bottom")
433
602
  ax2.xaxis.set_label_position("bottom")
@@ -435,7 +604,7 @@ class scat:
435
604
  # Offset the twin axis below the host
436
605
  ax2.spines["bottom"].set_position(("axes", -0.15))
437
606
 
438
- # Turn on the frame for the twin axis, but then hide all
607
+ # Turn on the frame for the twin axis, but then hide all
439
608
  # but the bottom spine
440
609
  ax2.set_frame_on(True)
441
610
  ax2.patch.set_visible(False)
@@ -443,72 +612,102 @@ class scat:
443
612
  for sp in ax2.spines.values():
444
613
  sp.set_visible(False)
445
614
  ax2.spines["bottom"].set_visible(True)
446
- ax2.set_xlim(0,n+2)
615
+ ax2.set_xlim(0, n + 2)
447
616
  ax2.set_xticks(tab2x)
448
- ax2.set_xticklabels(tab2nx,fontsize=6)
449
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
617
+ ax2.set_xticklabels(tab2nx, fontsize=6)
618
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
450
619
  ax1.legend(frameon=0)
451
-
452
- ax1=plt.subplot(2, 2, 4)
620
+
621
+ ax1 = plt.subplot(2, 2, 4)
453
622
  ax2 = ax1.twiny()
454
- n=0
455
- tmp=abs(self.get_np(self.S2L))
456
- lname=r'%s $S2_{2}$' % (name)
457
- ax1.set_ylabel(r'$S_{2}$ [L2 norm]')
458
- test=None
459
- tabx=[]
460
- tabnx=[]
461
- tab2x=[]
462
- tab2nx=[]
463
- if len(tmp.shape)==5:
623
+ n = 0
624
+ tmp = abs(self.get_np(self.S2L))
625
+ lname = r"%s $S2_{2}$" % (name)
626
+ ax1.set_ylabel(r"$S_{2}$ [L2 norm]")
627
+ test = None
628
+ tabx = []
629
+ tabnx = []
630
+ tab2x = []
631
+ tab2nx = []
632
+ if len(tmp.shape) == 5:
464
633
  for i0 in range(tmp.shape[0]):
465
634
  for i1 in range(tmp.shape[1]):
466
- for i2 in range(j1.max()+1):
635
+ for i2 in range(j1.max() + 1):
467
636
  for i3 in range(tmp.shape[3]):
468
637
  for i4 in range(tmp.shape[4]):
469
- if j2[j1==i2].shape[0]==1:
470
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4],'.', \
471
- color=color, lw=lw)
638
+ if j2[j1 == i2].shape[0] == 1:
639
+ ax1.plot(
640
+ j2[j1 == i2] + n,
641
+ tmp[i0, i1, j1 == i2, i3, i4],
642
+ ".",
643
+ color=color,
644
+ lw=lw,
645
+ )
472
646
  else:
473
647
  if legend and test is None:
474
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
475
- color=color, label=lname, lw=lw)
476
- test=1
477
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
478
- color=color, lw=lw)
479
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
480
- tabx=tabx+[k+n for k in j2[j1==i2]]
481
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
482
- tab2nx=tab2nx+['%d'%(i2)]
483
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
484
- n=n+j2[j1==i2].shape[0]-1
648
+ ax1.plot(
649
+ j2[j1 == i2] + n,
650
+ tmp[i0, i1, j1 == i2, i3, i4],
651
+ color=color,
652
+ label=lname,
653
+ lw=lw,
654
+ )
655
+ test = 1
656
+ ax1.plot(
657
+ j2[j1 == i2] + n,
658
+ tmp[i0, i1, j1 == i2, i3, i4],
659
+ color=color,
660
+ lw=lw,
661
+ )
662
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
663
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
664
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
665
+ tab2nx = tab2nx + ["%d" % (i2)]
666
+ ax1.axvline(
667
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
668
+ )
669
+ n = n + j2[j1 == i2].shape[0] - 1
485
670
  else:
486
671
  for i0 in range(tmp.shape[0]):
487
- for i2 in range(j1.max()+1):
672
+ for i2 in range(j1.max() + 1):
488
673
  for i3 in range(tmp.shape[2]):
489
674
  for i4 in range(tmp.shape[3]):
490
- if j2[j1==i2].shape[0]==1:
491
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4],'.', \
492
- color=color, lw=lw)
675
+ if j2[j1 == i2].shape[0] == 1:
676
+ ax1.plot(
677
+ j2[j1 == i2] + n,
678
+ tmp[i0, j1 == i2, i3, i4],
679
+ ".",
680
+ color=color,
681
+ lw=lw,
682
+ )
493
683
  else:
494
684
  if legend and test is None:
495
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
496
- color=color, label=lname, lw=lw)
497
- test=1
498
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
499
- color=color, lw=lw)
500
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
501
- tabx=tabx+[k+n for k in j2[j1==i2]]
502
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
503
- tab2nx=tab2nx+['%d'%(i2)]
504
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
505
- n=n+j2[j1==i2].shape[0]-1
506
- plt.yscale('log')
507
- ax1.set_xlim(-1,n+3)
685
+ ax1.plot(
686
+ j2[j1 == i2] + n,
687
+ tmp[i0, j1 == i2, i3, i4],
688
+ color=color,
689
+ label=lname,
690
+ lw=lw,
691
+ )
692
+ test = 1
693
+ ax1.plot(
694
+ j2[j1 == i2] + n,
695
+ tmp[i0, j1 == i2, i3, i4],
696
+ color=color,
697
+ lw=lw,
698
+ )
699
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
700
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
701
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
702
+ tab2nx = tab2nx + ["%d" % (i2)]
703
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
704
+ n = n + j2[j1 == i2].shape[0] - 1
705
+ plt.yscale("log")
706
+ ax1.set_xlim(-1, n + 3)
508
707
  ax1.set_xticks(tabx)
509
- ax1.set_xticklabels(tabnx,fontsize=6)
510
- ax1.set_xlabel(r"$j_{2}$",fontsize=6)
511
-
708
+ ax1.set_xticklabels(tabnx, fontsize=6)
709
+ ax1.set_xlabel(r"$j_{2}$", fontsize=6)
710
+
512
711
  # Move twinned axis ticks and label from top to bottom
513
712
  ax2.xaxis.set_ticks_position("bottom")
514
713
  ax2.xaxis.set_label_position("bottom")
@@ -516,7 +715,7 @@ class scat:
516
715
  # Offset the twin axis below the host
517
716
  ax2.spines["bottom"].set_position(("axes", -0.15))
518
717
 
519
- # Turn on the frame for the twin axis, but then hide all
718
+ # Turn on the frame for the twin axis, but then hide all
520
719
  # but the bottom spine
521
720
  ax2.set_frame_on(True)
522
721
  ax2.patch.set_visible(False)
@@ -524,252 +723,351 @@ class scat:
524
723
  for sp in ax2.spines.values():
525
724
  sp.set_visible(False)
526
725
  ax2.spines["bottom"].set_visible(True)
527
- ax2.set_xlim(0,n+3)
726
+ ax2.set_xlim(0, n + 3)
528
727
  ax2.set_xticks(tab2x)
529
- ax2.set_xticklabels(tab2nx,fontsize=6)
530
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
728
+ ax2.set_xticklabels(tab2nx, fontsize=6)
729
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
531
730
  ax1.legend(frameon=0)
532
-
533
- def save(self,filename):
534
- outlist=[self.get_S0().numpy(), \
535
- self.get_S1().numpy(), \
536
- self.get_S2().numpy(), \
537
- self.get_S2L().numpy(), \
538
- self.get_P00().numpy(), \
539
- self.j1, \
540
- self.j2]
541
-
542
- myout=open("%s.pkl"%(filename),"wb")
543
- pickle.dump(outlist,myout)
731
+
732
+ def save(self, filename):
733
+ outlist = [
734
+ self.get_S0().numpy(),
735
+ self.get_S1().numpy(),
736
+ self.get_S2().numpy(),
737
+ self.get_S2L().numpy(),
738
+ self.get_P00().numpy(),
739
+ self.j1,
740
+ self.j2,
741
+ ]
742
+
743
+ myout = open("%s.pkl" % (filename), "wb")
744
+ pickle.dump(outlist, myout)
544
745
  myout.close()
545
746
 
546
-
547
- def read(self,filename):
548
-
549
- outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
550
- return scat(outlist[4],outlist[0],outlist[1],outlist[2],outlist[3],outlist[5],outlist[6],backend=bk.foscat_backend('numpy'))
551
-
552
- def get_np(self,x):
747
+ def read(self, filename):
748
+
749
+ outlist = pickle.load(open("%s.pkl" % (filename), "rb"))
750
+ return scat(
751
+ outlist[4],
752
+ outlist[0],
753
+ outlist[1],
754
+ outlist[2],
755
+ outlist[3],
756
+ outlist[5],
757
+ outlist[6],
758
+ backend=bk.foscat_backend("numpy"),
759
+ )
760
+
761
+ def get_np(self, x):
553
762
  if isinstance(x, np.ndarray):
554
763
  return x
555
764
  else:
556
765
  return x.numpy()
557
766
 
558
767
  def std(self):
559
- return np.sqrt(((abs(self.get_np(self.S0)).std())**2+ \
560
- (abs(self.get_np(self.S1)).std())**2+ \
561
- (abs(self.get_np(self.S2)).std())**2+ \
562
- (abs(self.get_np(self.S2L)).std())**2+ \
563
- (abs(self.get_np(self.P00)).std())**2)/4)
768
+ return np.sqrt(
769
+ (
770
+ (abs(self.get_np(self.S0)).std()) ** 2
771
+ + (abs(self.get_np(self.S1)).std()) ** 2
772
+ + (abs(self.get_np(self.S2)).std()) ** 2
773
+ + (abs(self.get_np(self.S2L)).std()) ** 2
774
+ + (abs(self.get_np(self.P00)).std()) ** 2
775
+ )
776
+ / 4
777
+ )
564
778
 
565
779
  def mean(self):
566
- return abs(self.get_np(self.S0).mean()+ \
567
- self.get_np(self.S1).mean()+ \
568
- self.get_np(self.S2).mean()+ \
569
- self.get_np(self.S2L).mean()+ \
570
- self.get_np(self.P00).mean())/3
780
+ return (
781
+ abs(
782
+ self.get_np(self.S0).mean()
783
+ + self.get_np(self.S1).mean()
784
+ + self.get_np(self.S2).mean()
785
+ + self.get_np(self.S2L).mean()
786
+ + self.get_np(self.P00).mean()
787
+ )
788
+ / 3
789
+ )
571
790
 
572
791
  def sqrt(self):
573
792
 
793
+ s0 = self.backend.bk_sqrt(self.S0)
794
+ s1 = self.backend.bk_sqrt(self.S1)
795
+ p00 = self.backend.bk_sqrt(self.P00)
796
+ s2 = self.backend.bk_sqrt(self.S2)
797
+ s2L = self.backend.bk_sqrt(self.S2L)
574
798
 
575
- s0 =self.backend.bk_sqrt(self.S0)
576
- s1 =self.backend.bk_sqrt(self.S1)
577
- p00=self.backend.bk_sqrt(self.P00)
578
- s2 =self.backend.bk_sqrt(self.S2)
579
- s2L=self.backend.bk_sqrt(self.S2L)
580
-
581
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
582
-
799
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
583
800
 
584
801
  def L1(self):
585
802
 
803
+ s0 = self.backend.bk_L1(self.S0)
804
+ s1 = self.backend.bk_L1(self.S1)
805
+ p00 = self.backend.bk_L1(self.P00)
806
+ s2 = self.backend.bk_L1(self.S2)
807
+ s2L = self.backend.bk_L1(self.S2L)
808
+
809
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
586
810
 
587
- s0 =self.backend.bk_L1(self.S0)
588
- s1 =self.backend.bk_L1(self.S1)
589
- p00=self.backend.bk_L1(self.P00)
590
- s2 =self.backend.bk_L1(self.S2)
591
- s2L=self.backend.bk_L1(self.S2L)
592
-
593
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
594
-
595
811
  def square_comp(self):
596
812
 
813
+ s0 = self.backend.bk_square_comp(self.S0)
814
+ s1 = self.backend.bk_square_comp(self.S1)
815
+ p00 = self.backend.bk_square_comp(self.P00)
816
+ s2 = self.backend.bk_square_comp(self.S2)
817
+ s2L = self.backend.bk_square_comp(self.S2L)
818
+
819
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
597
820
 
598
- s0 =self.backend.bk_square_comp(self.S0)
599
- s1 =self.backend.bk_square_comp(self.S1)
600
- p00=self.backend.bk_square_comp(self.P00)
601
- s2 =self.backend.bk_square_comp(self.S2)
602
- s2L=self.backend.bk_square_comp(self.S2L)
603
-
604
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
605
-
606
- def iso_mean(self,repeat=False):
607
- shape=list(self.S2.shape)
608
- norient=self.S1.shape[2]
821
+ def iso_mean(self, repeat=False):
822
+ shape = list(self.S2.shape)
823
+ norient = self.S1.shape[2]
609
824
 
610
- S1 = self.backend.bk_reduce_mean(self.S1,2)
825
+ S1 = self.backend.bk_reduce_mean(self.S1, 2)
611
826
  if repeat:
612
- S1=self.backend.bk_reshape(self.backend.bk_repeat(S1,shape[2],1),self.S1.shape)
827
+ S1 = self.backend.bk_reshape(
828
+ self.backend.bk_repeat(S1, shape[2], 1), self.S1.shape
829
+ )
613
830
  else:
614
- S1=self.backend.bk_expand_dims(S1,-1)
615
-
831
+ S1 = self.backend.bk_expand_dims(S1, -1)
616
832
 
617
- P00 = self.backend.bk_reduce_mean(self.P00,2)
833
+ P00 = self.backend.bk_reduce_mean(self.P00, 2)
618
834
  if repeat:
619
- P00=self.backend.bk_reshape(self.backend.bk_repeat(P00,shape[2],1),self.S1.shape)
835
+ P00 = self.backend.bk_reshape(
836
+ self.backend.bk_repeat(P00, shape[2], 1), self.S1.shape
837
+ )
620
838
  else:
621
- P00=self.backend.bk_expand_dims(P00,-1)
839
+ P00 = self.backend.bk_expand_dims(P00, -1)
622
840
 
623
841
  if norient not in self.backend._iso_orient:
624
842
  self.backend.calc_iso_orient(norient)
625
-
843
+
626
844
  if self.backend.bk_is_complex(self.S2):
627
- lmat = self.backend._iso_orient_C[norient]
845
+ lmat = self.backend._iso_orient_C[norient]
628
846
  lmat_T = self.backend._iso_orient_C_T[norient]
629
847
  else:
630
- lmat = self.backend._iso_orient[norient]
848
+ lmat = self.backend._iso_orient[norient]
631
849
  lmat_T = self.backend._iso_orient_T[norient]
632
-
633
- S2=self.backend.bk_reshape(
634
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2,[shape[0],shape[1],norient*norient]),lmat),
635
- [shape[0],shape[1],norient])
636
- S2L=self.backend.bk_reshape(
637
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2L,[shape[0],shape[1],norient*norient]),lmat),
638
- [shape[0],shape[1],norient])
639
-
850
+
851
+ S2 = self.backend.bk_reshape(
852
+ self.backend.backend.matmul(
853
+ self.backend.bk_reshape(
854
+ self.S2, [shape[0], shape[1], norient * norient]
855
+ ),
856
+ lmat,
857
+ ),
858
+ [shape[0], shape[1], norient],
859
+ )
860
+ S2L = self.backend.bk_reshape(
861
+ self.backend.backend.matmul(
862
+ self.backend.bk_reshape(
863
+ self.S2L, [shape[0], shape[1], norient * norient]
864
+ ),
865
+ lmat,
866
+ ),
867
+ [shape[0], shape[1], norient],
868
+ )
869
+
640
870
  if repeat:
641
-
642
- S2=self.backend.bk_reshape(
643
- self.backend.backend.matmul(self.backend.bk_reshape(S2,[shape[0]*shape[1],norient]),lmat_T),
644
- self.S2.shape)
645
- S2L=self.backend.bk_reshape(
646
- self.backend.backend.matmul(self.backend.bk_reshape(S2L,[shape[0]*shape[1],norient]),lmat_T),
647
- self.S2.shape)
871
+
872
+ S2 = self.backend.bk_reshape(
873
+ self.backend.backend.matmul(
874
+ self.backend.bk_reshape(S2, [shape[0] * shape[1], norient]), lmat_T
875
+ ),
876
+ self.S2.shape,
877
+ )
878
+ S2L = self.backend.bk_reshape(
879
+ self.backend.backend.matmul(
880
+ self.backend.bk_reshape(S2L, [shape[0] * shape[1], norient]), lmat_T
881
+ ),
882
+ self.S2.shape,
883
+ )
648
884
  else:
649
- S2=self.backend.bk_expand_dims(S2,-1)
650
- S2L=self.backend.bk_expand_dims(S2L,-1)
885
+ S2 = self.backend.bk_expand_dims(S2, -1)
886
+ S2L = self.backend.bk_expand_dims(S2L, -1)
651
887
 
652
- return scat(P00,self.S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
888
+ return scat(P00, self.S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
653
889
 
654
-
655
- def fft_ang(self,nharm=1,imaginary=False):
656
- shape=list(self.S2.shape)
657
- norient=self.S1.shape[2]
890
+ def fft_ang(self, nharm=1, imaginary=False):
891
+ shape = list(self.S2.shape)
892
+ norient = self.S1.shape[2]
658
893
 
659
- nout=1+nharm
894
+ nout = 1 + nharm
660
895
  if imaginary:
661
- nout=1+nharm*2
662
-
663
- if (norient,nharm) not in self.backend._fft_1_orient:
664
- self.backend.calc_fft_orient(norient,nharm,imaginary)
665
-
896
+ nout = 1 + nharm * 2
897
+
898
+ if (norient, nharm) not in self.backend._fft_1_orient:
899
+ self.backend.calc_fft_orient(norient, nharm, imaginary)
900
+
666
901
  if self.backend.bk_is_complex(self.S1):
667
- lmat = self.backend._fft_1_orient_C[(norient,nharm,imaginary)]
902
+ lmat = self.backend._fft_1_orient_C[(norient, nharm, imaginary)]
668
903
  else:
669
- lmat = self.backend._fft_1_orient[(norient,nharm,imaginary)]
670
-
671
- S1=self.backend.bk_reshape(
672
- self.backend.backend.matmul(self.backend.bk_reshape(self.S1,[self.S1.shape[0],self.S1.shape[1],norient]),lmat),
673
- [self.S1.shape[0],self.S1.shape[1],nout])
674
-
675
- P00=self.backend.bk_reshape(
676
- self.backend.backend.matmul(self.backend.bk_reshape(self.P00,[self.S1.shape[0],self.S1.shape[1],norient]),lmat),
677
- [self.S1.shape[0],self.S1.shape[1],nout])
678
-
679
-
904
+ lmat = self.backend._fft_1_orient[(norient, nharm, imaginary)]
905
+
906
+ S1 = self.backend.bk_reshape(
907
+ self.backend.backend.matmul(
908
+ self.backend.bk_reshape(
909
+ self.S1, [self.S1.shape[0], self.S1.shape[1], norient]
910
+ ),
911
+ lmat,
912
+ ),
913
+ [self.S1.shape[0], self.S1.shape[1], nout],
914
+ )
915
+
916
+ P00 = self.backend.bk_reshape(
917
+ self.backend.backend.matmul(
918
+ self.backend.bk_reshape(
919
+ self.P00, [self.S1.shape[0], self.S1.shape[1], norient]
920
+ ),
921
+ lmat,
922
+ ),
923
+ [self.S1.shape[0], self.S1.shape[1], nout],
924
+ )
925
+
680
926
  if self.backend.bk_is_complex(self.S2):
681
- lmat = self.backend._fft_2_orient_C[(norient,nharm,imaginary)]
927
+ lmat = self.backend._fft_2_orient_C[(norient, nharm, imaginary)]
682
928
  else:
683
- lmat = self.backend._fft_2_orient[(norient,nharm,imaginary)]
684
-
685
- S2=self.backend.bk_reshape(
686
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2,[shape[0],shape[1],norient*norient]),lmat),
687
- [shape[0],shape[1],nout,nout])
688
- S2L=self.backend.bk_reshape(
689
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2L,[shape[0],shape[1],norient*norient]),lmat),
690
- [shape[0],shape[1],nout,nout])
691
-
692
- return scat(P00,self.S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
693
-
694
-
695
- def iso_std(self,repeat=False):
696
-
697
- val=(self-self.iso_mean(repeat=True)).square_comp()
929
+ lmat = self.backend._fft_2_orient[(norient, nharm, imaginary)]
930
+
931
+ S2 = self.backend.bk_reshape(
932
+ self.backend.backend.matmul(
933
+ self.backend.bk_reshape(
934
+ self.S2, [shape[0], shape[1], norient * norient]
935
+ ),
936
+ lmat,
937
+ ),
938
+ [shape[0], shape[1], nout, nout],
939
+ )
940
+ S2L = self.backend.bk_reshape(
941
+ self.backend.backend.matmul(
942
+ self.backend.bk_reshape(
943
+ self.S2L, [shape[0], shape[1], norient * norient]
944
+ ),
945
+ lmat,
946
+ ),
947
+ [shape[0], shape[1], nout, nout],
948
+ )
949
+
950
+ return scat(P00, self.S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
951
+
952
+ def iso_std(self, repeat=False):
953
+
954
+ val = (self - self.iso_mean(repeat=True)).square_comp()
698
955
  return (val.iso_mean(repeat=repeat)).L1()
699
956
 
700
957
  # ---------------------------------------------−---------
701
- def cleanval(self,x):
702
- x=x.numpy()
703
- x[np.isfinite(x)==False]=np.median(x[np.isfinite(x)])
958
+ def cleanval(self, x):
959
+ x = x.numpy()
960
+ x[~np.isfinite(x)] = np.median(x[np.isfinite(x)])
704
961
  return x
705
962
 
706
963
  def filter_inf(self):
707
- S1 = self.cleanval(self.S1)
708
- S0 = self.cleanval(self.S0)
964
+ S1 = self.cleanval(self.S1)
965
+ S0 = self.cleanval(self.S0)
709
966
  P00 = self.cleanval(self.P00)
710
- S2 = self.cleanval(self.S2)
967
+ S2 = self.cleanval(self.S2)
711
968
  S2L = self.cleanval(self.S2L)
712
969
 
713
- return scat(P00,S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
970
+ return scat(P00, S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
714
971
 
715
972
  # ---------------------------------------------−---------
716
- def interp(self,nscale,extend=False,constant=False,threshold=1E30,use_mask=False):
717
-
718
- if nscale+2>self.S1.shape[1]:
719
- print('Can not *interp* %d with a statistic described over %d'%(nscale,self.S1.shape[1]))
720
- return scat(self.P00,self.S0,self.S1,self.S2,self.S2L,self.j1,self.j2,backend=self.backend)
721
- if isinstance(self.S1,np.ndarray):
722
- s1=self.S1
723
- p0=self.P00
724
- s2=self.S2
725
- s2l=self.S2L
973
+ def interp(
974
+ self, nscale, extend=False, constant=False, threshold=1e30, use_mask=False
975
+ ):
976
+
977
+ if nscale + 2 > self.S1.shape[1]:
978
+ print(
979
+ "Can not *interp* %d with a statistic described over %d"
980
+ % (nscale, self.S1.shape[1])
981
+ )
982
+ return scat(
983
+ self.P00,
984
+ self.S0,
985
+ self.S1,
986
+ self.S2,
987
+ self.S2L,
988
+ self.j1,
989
+ self.j2,
990
+ backend=self.backend,
991
+ )
992
+ if isinstance(self.S1, np.ndarray):
993
+ s1 = self.S1
994
+ p0 = self.P00
995
+ s2 = self.S2
996
+ s2l = self.S2L
726
997
  else:
727
- s1=self.S1.numpy()
728
- p0=self.P00.numpy()
729
- s2=self.S2.numpy()
730
- s2l=self.S2L.numpy()
731
-
732
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
733
-
734
- if isinstance(threshold,scat):
735
- if isinstance(threshold.S1,np.ndarray):
736
- s1th=threshold.S1
737
- p0th=threshold.P00
738
- s2th=threshold.S2
739
- s2lth=threshold.S2L
998
+ s1 = self.S1.numpy()
999
+ p0 = self.P00.numpy()
1000
+ s2 = self.S2.numpy()
1001
+ s2l = self.S2L.numpy()
1002
+
1003
+ print(s1.sum(), p0.sum(), s2.sum(), s2l.sum())
1004
+
1005
+ if isinstance(threshold, scat):
1006
+ if isinstance(threshold.S1, np.ndarray):
1007
+ s1th = threshold.S1
1008
+ p0th = threshold.P00
1009
+ s2th = threshold.S2
1010
+ s2lth = threshold.S2L
740
1011
  else:
741
- s1th=threshold.S1.numpy()
742
- p0th=threshold.P00.numpy()
743
- s2th=threshold.S2.numpy()
744
- s2lth=threshold.S2L.numpy()
1012
+ s1th = threshold.S1.numpy()
1013
+ p0th = threshold.P00.numpy()
1014
+ s2th = threshold.S2.numpy()
1015
+ s2lth = threshold.S2L.numpy()
745
1016
  else:
746
- s1th=threshold+0*s1
747
- p0th=threshold+0*p0
748
- s2th=threshold+0*s2
749
- s2lth=threshold+0*s2l
1017
+ s1th = threshold + 0 * s1
1018
+ p0th = threshold + 0 * p0
1019
+ s2th = threshold + 0 * s2
1020
+ s2lth = threshold + 0 * s2l
750
1021
 
751
1022
  for k in range(nscale):
752
1023
  if constant:
753
- s1[:,nscale-1-k,:]=s1[:,nscale-k,:]
754
- p0[:,nscale-1-k,:]=p0[:,nscale-k,:]
1024
+ s1[:, nscale - 1 - k, :] = s1[:, nscale - k, :]
1025
+ p0[:, nscale - 1 - k, :] = p0[:, nscale - k, :]
755
1026
  else:
756
- idx=np.where((s1[:,nscale+1-k,:]>0)*(s1[:,nscale+2-k,:]>0)*(s1[:,nscale-k,:]<s1th[:,nscale-k,:]))
757
- if len(idx[0])>0:
758
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(s1[idx[0],nscale+1-k,idx[1]])-2*np.log(s1[idx[0],nscale+2-k,idx[1]]))
759
- idx=np.where((s1[:,nscale-k,:]>0)*(s1[:,nscale+1-k,:]>0)*(s1[:,nscale-1-k,:]<s1th[:,nscale-1-k,:]))
760
- if len(idx[0])>0:
761
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(s1[idx[0],nscale-k,idx[1]])-np.log(s1[idx[0],nscale+1-k,idx[1]]))
762
-
763
- idx=np.where((p0[:,nscale+1-k,:]>0)*(p0[:,nscale+2-k,:]>0)*(p0[:,nscale-k,:]<p0th[:,nscale-k,:]))
764
- if len(idx[0])>0:
765
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(p0[idx[0],nscale+1-k,idx[1]])-2*np.log(p0[idx[0],nscale+2-k,idx[1]]))
766
-
767
- idx=np.where((p0[:,nscale-k,:]>0)*(p0[:,nscale+1-k,:]>0)*(p0[:,nscale-1-k,:]<p0th[:,nscale-1-k,:]))
768
- if len(idx[0])>0:
769
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(p0[idx[0],nscale-k,idx[1]])-np.log(p0[idx[0],nscale+1-k,idx[1]]))
770
-
771
-
772
- j1,j2=self.get_j_idx()
1027
+ idx = np.where(
1028
+ (s1[:, nscale + 1 - k, :] > 0)
1029
+ * (s1[:, nscale + 2 - k, :] > 0)
1030
+ * (s1[:, nscale - k, :] < s1th[:, nscale - k, :])
1031
+ )
1032
+ if len(idx[0]) > 0:
1033
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1034
+ 3 * np.log(s1[idx[0], nscale + 1 - k, idx[1]])
1035
+ - 2 * np.log(s1[idx[0], nscale + 2 - k, idx[1]])
1036
+ )
1037
+ idx = np.where(
1038
+ (s1[:, nscale - k, :] > 0)
1039
+ * (s1[:, nscale + 1 - k, :] > 0)
1040
+ * (s1[:, nscale - 1 - k, :] < s1th[:, nscale - 1 - k, :])
1041
+ )
1042
+ if len(idx[0]) > 0:
1043
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1044
+ 2 * np.log(s1[idx[0], nscale - k, idx[1]])
1045
+ - np.log(s1[idx[0], nscale + 1 - k, idx[1]])
1046
+ )
1047
+
1048
+ idx = np.where(
1049
+ (p0[:, nscale + 1 - k, :] > 0)
1050
+ * (p0[:, nscale + 2 - k, :] > 0)
1051
+ * (p0[:, nscale - k, :] < p0th[:, nscale - k, :])
1052
+ )
1053
+ if len(idx[0]) > 0:
1054
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1055
+ 3 * np.log(p0[idx[0], nscale + 1 - k, idx[1]])
1056
+ - 2 * np.log(p0[idx[0], nscale + 2 - k, idx[1]])
1057
+ )
1058
+
1059
+ idx = np.where(
1060
+ (p0[:, nscale - k, :] > 0)
1061
+ * (p0[:, nscale + 1 - k, :] > 0)
1062
+ * (p0[:, nscale - 1 - k, :] < p0th[:, nscale - 1 - k, :])
1063
+ )
1064
+ if len(idx[0]) > 0:
1065
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1066
+ 2 * np.log(p0[idx[0], nscale - k, idx[1]])
1067
+ - np.log(p0[idx[0], nscale + 1 - k, idx[1]])
1068
+ )
1069
+
1070
+ j1, j2 = self.get_j_idx()
773
1071
 
774
1072
  for k in range(nscale):
775
1073
 
@@ -782,669 +1080,951 @@ class scat:
782
1080
  s2l[:,i0]=np.exp(2*np.log(s2l[:,i1])-np.log(s2l[:,i2]))
783
1081
  """
784
1082
 
785
- for l in range(nscale-k):
786
- i0=np.where((j1==nscale-1-k-l)*(j2==nscale-1-k))[0]
787
- i1=np.where((j1==nscale-1-k-l)*(j2==nscale -k))[0]
788
- i2=np.where((j1==nscale-1-k-l)*(j2==nscale+1-k))[0]
789
- i3=np.where((j1==nscale-1-k-l)*(j2==nscale+2-k))[0]
790
-
1083
+ for l_scale in range(nscale - k):
1084
+ i0 = np.where(
1085
+ (j1 == nscale - 1 - k - l_scale) * (j2 == nscale - 1 - k)
1086
+ )[0]
1087
+ i1 = np.where((j1 == nscale - 1 - k - l_scale) * (j2 == nscale - k))[0]
1088
+ i2 = np.where(
1089
+ (j1 == nscale - 1 - k - l_scale) * (j2 == nscale + 1 - k)
1090
+ )[0]
1091
+ i3 = np.where(
1092
+ (j1 == nscale - 1 - k - l_scale) * (j2 == nscale + 2 - k)
1093
+ )[0]
1094
+
791
1095
  if constant:
792
- s2[:,i0]=s2[:,i1]
793
- s2l[:,i0]=s2l[:,i1]
1096
+ s2[:, i0] = s2[:, i1]
1097
+ s2l[:, i0] = s2l[:, i1]
794
1098
  else:
795
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
796
- if len(idx[0])>0:
797
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
798
-
799
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
800
- if len(idx[0])>0:
801
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
802
-
803
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
804
- if len(idx[0])>0:
805
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
806
-
807
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
808
- if len(idx[0])>0:
809
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
810
-
1099
+ idx = np.where(
1100
+ (s2[:, i2] > 0) * (s2[:, i3] > 0) * (s2[:, i2] < s2th[:, i2])
1101
+ )
1102
+ if len(idx[0]) > 0:
1103
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1104
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
1105
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
1106
+ )
1107
+
1108
+ idx = np.where(
1109
+ (s2[:, i1] > 0) * (s2[:, i2] > 0) * (s2[:, i1] < s2th[:, i1])
1110
+ )
1111
+ if len(idx[0]) > 0:
1112
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1113
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1114
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1115
+ )
1116
+
1117
+ idx = np.where(
1118
+ (s2l[:, i2] > 0)
1119
+ * (s2l[:, i3] > 0)
1120
+ * (s2l[:, i2] < s2lth[:, i2])
1121
+ )
1122
+ if len(idx[0]) > 0:
1123
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1124
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1125
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1126
+ )
1127
+
1128
+ idx = np.where(
1129
+ (s2l[:, i1] > 0)
1130
+ * (s2l[:, i2] > 0)
1131
+ * (s2l[:, i1] < s2lth[:, i1])
1132
+ )
1133
+ if len(idx[0]) > 0:
1134
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1135
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1136
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1137
+ )
1138
+
811
1139
  if extend:
812
1140
  for k in range(nscale):
813
- for l in range(1,nscale):
814
- i0=np.where((j1==2*nscale-1-k)*(j2==2*nscale-1-k-l))[0]
815
- i1=np.where((j1==2*nscale-1-k)*(j2==2*nscale -k-l))[0]
816
- i2=np.where((j1==2*nscale-1-k)*(j2==2*nscale+1-k-l))[0]
817
- i3=np.where((j1==2*nscale-1-k)*(j2==2*nscale+2-k-l))[0]
1141
+ for l_scale in range(1, nscale):
1142
+ i0 = np.where(
1143
+ (j1 == 2 * nscale - 1 - k)
1144
+ * (j2 == 2 * nscale - 1 - k - l_scale)
1145
+ )[0]
1146
+ i1 = np.where(
1147
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale - k - l_scale)
1148
+ )[0]
1149
+ i2 = np.where(
1150
+ (j1 == 2 * nscale - 1 - k)
1151
+ * (j2 == 2 * nscale + 1 - k - l_scale)
1152
+ )[0]
1153
+ i3 = np.where(
1154
+ (j1 == 2 * nscale - 1 - k)
1155
+ * (j2 == 2 * nscale + 2 - k - l_scale)
1156
+ )[0]
818
1157
  if constant:
819
- s2[:,i0]=s2[:,i1]
820
- s2l[:,i0]=s2l[:,i1]
1158
+ s2[:, i0] = s2[:, i1]
1159
+ s2l[:, i0] = s2l[:, i1]
821
1160
  else:
822
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
823
- if len(idx[0])>0:
824
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
825
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
826
- if len(idx[0])>0:
827
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
828
-
829
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
830
- if len(idx[0])>0:
831
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
832
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
833
- if len(idx[0])>0:
834
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
835
-
836
- s1[np.isnan(s1)]=0.0
837
- p0[np.isnan(p0)]=0.0
838
- s2[np.isnan(s2)]=0.0
839
- s2l[np.isnan(s2l)]=0.0
840
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
841
-
842
- return scat(self.backend.constant(p0),self.S0,
843
- self.backend.constant(s1),
844
- self.backend.constant(s2),
845
- self.backend.constant(s2l),self.j1,self.j2,backend=self.backend)
1161
+ idx = np.where(
1162
+ (s2[:, i2] > 0)
1163
+ * (s2[:, i3] > 0)
1164
+ * (s2[:, i2] < s2th[:, i2])
1165
+ )
1166
+ if len(idx[0]) > 0:
1167
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1168
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
1169
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
1170
+ )
1171
+ idx = np.where(
1172
+ (s2[:, i1] > 0)
1173
+ * (s2[:, i2] > 0)
1174
+ * (s2[:, i1] < s2th[:, i1])
1175
+ )
1176
+ if len(idx[0]) > 0:
1177
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1178
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1179
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1180
+ )
1181
+
1182
+ idx = np.where(
1183
+ (s2l[:, i2] > 0)
1184
+ * (s2l[:, i3] > 0)
1185
+ * (s2l[:, i2] < s2lth[:, i2])
1186
+ )
1187
+ if len(idx[0]) > 0:
1188
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1189
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1190
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1191
+ )
1192
+ idx = np.where(
1193
+ (s2l[:, i1] > 0)
1194
+ * (s2l[:, i2] > 0)
1195
+ * (s2l[:, i1] < s2lth[:, i1])
1196
+ )
1197
+ if len(idx[0]) > 0:
1198
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1199
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1200
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1201
+ )
1202
+
1203
+ s1[np.isnan(s1)] = 0.0
1204
+ p0[np.isnan(p0)] = 0.0
1205
+ s2[np.isnan(s2)] = 0.0
1206
+ s2l[np.isnan(s2l)] = 0.0
1207
+ print(s1.sum(), p0.sum(), s2.sum(), s2l.sum())
1208
+
1209
+ return scat(
1210
+ self.backend.constant(p0),
1211
+ self.S0,
1212
+ self.backend.constant(s1),
1213
+ self.backend.constant(s2),
1214
+ self.backend.constant(s2l),
1215
+ self.j1,
1216
+ self.j2,
1217
+ backend=self.backend,
1218
+ )
846
1219
 
847
1220
  # ---------------------------------------------−---------
848
1221
  def flatten(self):
849
- if isinstance(self.S1,np.ndarray):
850
- return np.concatenate([self.S0.flatten(),
851
- self.S1.flatten(),
852
- self.P00.flatten(),
853
- self.S2.flatten(),
854
- self.S2L.flatten()],0)
1222
+ if isinstance(self.S1, np.ndarray):
1223
+ return np.concatenate(
1224
+ [
1225
+ self.S0.flatten(),
1226
+ self.S1.flatten(),
1227
+ self.P00.flatten(),
1228
+ self.S2.flatten(),
1229
+ self.S2L.flatten(),
1230
+ ],
1231
+ 0,
1232
+ )
855
1233
  else:
856
- return self.backend.bk_concat([self.backend.bk_flattenR(self.S0),
857
- self.backend.bk_flattenR(self.S1),
858
- self.backend.bk_flattenR(self.P00),
859
- self.backend.bk_flattenR(self.S2),
860
- self.backend.bk_flattenR(self.S2)],axis=0)
1234
+ return self.backend.bk_concat(
1235
+ [
1236
+ self.backend.bk_flattenR(self.S0),
1237
+ self.backend.bk_flattenR(self.S1),
1238
+ self.backend.bk_flattenR(self.P00),
1239
+ self.backend.bk_flattenR(self.S2),
1240
+ self.backend.bk_flattenR(self.S2),
1241
+ ],
1242
+ axis=0,
1243
+ )
861
1244
 
862
1245
  # ---------------------------------------------−---------
863
1246
  def flattenMask(self):
864
- if isinstance(self.S1,np.ndarray):
865
- tmp=np.expand_dims(np.concatenate([self.S1[0].flatten(),
866
- self.P00[0].flatten(),
867
- self.S2[0].flatten(),
868
- self.S2L[0].flatten()],0),0)
869
- for k in range(1,self.P00.shape[0]):
870
- tmp=np.concatenate([tmp,np.expand_dims(np.concatenate([self.S1[k].flatten(),
871
- self.P00[k].flatten(),
872
- self.S2[k].flatten(),
873
- self.S2L[k].flatten()],0),0)],0)
874
-
875
-
876
- return np.concatenate([tmp,np.expand_dims(self.S0,1)],1)
1247
+ if isinstance(self.S1, np.ndarray):
1248
+ tmp = np.expand_dims(
1249
+ np.concatenate(
1250
+ [
1251
+ self.S1[0].flatten(),
1252
+ self.P00[0].flatten(),
1253
+ self.S2[0].flatten(),
1254
+ self.S2L[0].flatten(),
1255
+ ],
1256
+ 0,
1257
+ ),
1258
+ 0,
1259
+ )
1260
+ for k in range(1, self.P00.shape[0]):
1261
+ tmp = np.concatenate(
1262
+ [
1263
+ tmp,
1264
+ np.expand_dims(
1265
+ np.concatenate(
1266
+ [
1267
+ self.S1[k].flatten(),
1268
+ self.P00[k].flatten(),
1269
+ self.S2[k].flatten(),
1270
+ self.S2L[k].flatten(),
1271
+ ],
1272
+ 0,
1273
+ ),
1274
+ 0,
1275
+ ),
1276
+ ],
1277
+ 0,
1278
+ )
1279
+
1280
+ return np.concatenate([tmp, np.expand_dims(self.S0, 1)], 1)
877
1281
  else:
878
- tmp=self.backend.bk_expand_dims(self.backend.bk_concat([self.backend.bk_flattenR(self.S1[0]),
879
- self.backend.bk_flattenR(self.P00[0]),
880
- self.backend.bk_flattenR(self.S2[0]),
881
- self.backend.bk_flattenR(self.S2[0])],axis=0),0)
882
- for k in range(1,self.P00.shape[0]):
883
- ltmp=self.backend.bk_expand_dims(self.backend.bk_concat([self.backend.bk_flattenR(self.S1[k]),
884
- self.backend.bk_flattenR(self.P00[k]),
885
- self.backend.bk_flattenR(self.S2[k]),
886
- self.backend.bk_flattenR(self.S2[k])],axis=0),0)
887
- tmp=self.backend.bk_concat([tmp,ltmp],0)
888
-
889
- return self.backend.bk_concat([tmp,self.backend.bk_expand_dims(self.S0,1)],1)
890
-
1282
+ tmp = self.backend.bk_expand_dims(
1283
+ self.backend.bk_concat(
1284
+ [
1285
+ self.backend.bk_flattenR(self.S1[0]),
1286
+ self.backend.bk_flattenR(self.P00[0]),
1287
+ self.backend.bk_flattenR(self.S2[0]),
1288
+ self.backend.bk_flattenR(self.S2[0]),
1289
+ ],
1290
+ axis=0,
1291
+ ),
1292
+ 0,
1293
+ )
1294
+ for k in range(1, self.P00.shape[0]):
1295
+ ltmp = self.backend.bk_expand_dims(
1296
+ self.backend.bk_concat(
1297
+ [
1298
+ self.backend.bk_flattenR(self.S1[k]),
1299
+ self.backend.bk_flattenR(self.P00[k]),
1300
+ self.backend.bk_flattenR(self.S2[k]),
1301
+ self.backend.bk_flattenR(self.S2[k]),
1302
+ ],
1303
+ axis=0,
1304
+ ),
1305
+ 0,
1306
+ )
1307
+ tmp = self.backend.bk_concat([tmp, ltmp], 0)
1308
+
1309
+ return self.backend.bk_concat(
1310
+ [tmp, self.backend.bk_expand_dims(self.S0, 1)], 1
1311
+ )
1312
+
891
1313
  # ---------------------------------------------−---------
892
- def model(self,i__y,add=0,dx=3,dell=2,weigth=None,inverse=False):
1314
+ def model(self, i__y, add=0, dx=3, dell=2, weigth=None, inverse=False):
893
1315
 
894
- if i__y.shape[0]<dx+1:
895
- l__dx=i__y.shape[0]-1
1316
+ if i__y.shape[0] < dx + 1:
1317
+ l__dx = i__y.shape[0] - 1
896
1318
  else:
897
- l__dx=dx
1319
+ l__dx = dx
898
1320
 
899
- if i__y.shape[0]<dell:
900
- l__dell=0
1321
+ if i__y.shape[0] < dell:
1322
+ l__dell = 0
901
1323
  else:
902
- l__dell=dell
1324
+ l__dell = dell
903
1325
 
904
- if l__dx<2:
905
- res=np.zeros([i__y.shape[0]+add])
1326
+ if l__dx < 2:
1327
+ res = np.zeros([i__y.shape[0] + add])
906
1328
  if inverse:
907
- res[:-add]=i__y
1329
+ res[:-add] = i__y
908
1330
  else:
909
- res[add:]=i__y[0:]
1331
+ res[add:] = i__y[0:]
910
1332
  return res
911
1333
 
912
1334
  if weigth is None:
913
- w=2**(np.arange(l__dx))
1335
+ w = 2 ** (np.arange(l__dx))
914
1336
  else:
915
1337
  if not inverse:
916
- w=weigth[0:l__dx]
1338
+ w = weigth[0:l__dx]
917
1339
  else:
918
- w=weigth[-l__dx:]
1340
+ w = weigth[-l__dx:]
919
1341
 
920
- x=np.arange(l__dx)+1
1342
+ x = np.arange(l__dx) + 1
921
1343
  if not inverse:
922
- y=np.log(i__y[1:l__dx+1])
1344
+ y = np.log(i__y[1 : l__dx + 1])
923
1345
  else:
924
- y=np.log(i__y[-(l__dx+1):-1])
1346
+ y = np.log(i__y[-(l__dx + 1) : -1])
925
1347
 
926
- r=np.polyfit(x,y,1,w=w)
1348
+ r = np.polyfit(x, y, 1, w=w)
927
1349
 
928
1350
  if inverse:
929
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-1)+r[1])
930
- res[:-(l__dell+add)]=i__y[:-l__dell]
1351
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - 1) + r[1])
1352
+ res[: -(l__dell + add)] = i__y[:-l__dell]
931
1353
  else:
932
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-add)+r[1])
933
- res[l__dell+add:]=i__y[l__dell:]
1354
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - add) + r[1])
1355
+ res[l__dell + add :] = i__y[l__dell:]
934
1356
  return res
935
1357
 
936
- def findn(self,n):
937
- d=np.sqrt(1+8*n)
938
- return int((d-1)/2)
1358
+ def findn(self, n):
1359
+ d = np.sqrt(1 + 8 * n)
1360
+ return int((d - 1) / 2)
939
1361
 
940
- def findidx(self,s2):
941
- i1=np.zeros([s2.shape[1]],dtype='int')
942
- i2=np.zeros([s2.shape[1]],dtype='int')
943
- n=0
1362
+ def findidx(self, s2):
1363
+ i1 = np.zeros([s2.shape[1]], dtype="int")
1364
+ i2 = np.zeros([s2.shape[1]], dtype="int")
1365
+ n = 0
944
1366
  for k in range(self.findn(s2.shape[1])):
945
- i1[n:n+k+1]=np.arange(k+1)
946
- i2[n:n+k+1]=k
947
- n=n+k+1
948
- return i1,i2
949
-
950
- def extrapol_s2(self,add,lnorm=1):
951
- if lnorm==1:
952
- s2=self.S2.numpy()
953
- if lnorm==2:
954
- s2=self.S2L.numpy()
955
- i1,i2=self.findidx(s2)
956
-
957
- so2=np.zeros([s2.shape[0],(self.findn(s2.shape[1])+add)*(self.findn(s2.shape[1])+add+1)//2,s2.shape[2],s2.shape[3]])
958
- oi1,oi2=self.findidx(so2)
959
- for l in range(s2.shape[0]):
1367
+ i1[n : n + k + 1] = np.arange(k + 1)
1368
+ i2[n : n + k + 1] = k
1369
+ n = n + k + 1
1370
+ return i1, i2
1371
+
1372
+ def extrapol_s2(self, add, lnorm=1):
1373
+ if lnorm == 1:
1374
+ s2 = self.S2.numpy()
1375
+ if lnorm == 2:
1376
+ s2 = self.S2L.numpy()
1377
+ i1, i2 = self.findidx(s2)
1378
+
1379
+ so2 = np.zeros(
1380
+ [
1381
+ s2.shape[0],
1382
+ (self.findn(s2.shape[1]) + add)
1383
+ * (self.findn(s2.shape[1]) + add + 1)
1384
+ // 2,
1385
+ s2.shape[2],
1386
+ s2.shape[3],
1387
+ ]
1388
+ )
1389
+ oi1, oi2 = self.findidx(so2)
1390
+ for l_batch in range(s2.shape[0]):
960
1391
  for k in range(self.findn(s2.shape[1])):
961
1392
  for i in range(s2.shape[2]):
962
1393
  for j in range(s2.shape[3]):
963
- tmp=self.model(s2[l,i2==k,i,j],dx=4,dell=1,add=add,weigth=np.array([1,2,2,2]))
964
- tmp[np.isnan(tmp)]=0.0
965
- so2[l,oi2==k+add,i,j]=tmp
966
-
967
-
968
- for l in range(s2.shape[0]):
969
- for k in range(add+1,-1,-1):
970
- lidx=np.where(oi2-oi1==k)[0]
971
- lidx2=np.where(oi2-oi1==k+1)[0]
1394
+ tmp = self.model(
1395
+ s2[l_batch, i2 == k, i, j],
1396
+ dx=4,
1397
+ dell=1,
1398
+ add=add,
1399
+ weigth=np.array([1, 2, 2, 2]),
1400
+ )
1401
+ tmp[np.isnan(tmp)] = 0.0
1402
+ so2[l_batch, oi2 == k + add, i, j] = tmp
1403
+
1404
+ for l_batch in range(s2.shape[0]):
1405
+ for k in range(add + 1, -1, -1):
1406
+ lidx = np.where(oi2 - oi1 == k)[0]
1407
+ lidx2 = np.where(oi2 - oi1 == k + 1)[0]
972
1408
  for i in range(s2.shape[2]):
973
1409
  for j in range(s2.shape[3]):
974
- so2[l,lidx[0:add+2-k],i,j]=so2[l,lidx2[0:add+2-k],i,j]
1410
+ so2[l_batch, lidx[0 : add + 2 - k], i, j] = so2[
1411
+ l_batch, lidx2[0 : add + 2 - k], i, j
1412
+ ]
975
1413
 
976
- return(so2)
1414
+ return so2
977
1415
 
978
- def extrapol_s1(self,i_s1,add):
979
- s1=i_s1.numpy()
980
- so1=np.zeros([s1.shape[0],s1.shape[1]+add,s1.shape[2]])
1416
+ def extrapol_s1(self, i_s1, add):
1417
+ s1 = i_s1.numpy()
1418
+ so1 = np.zeros([s1.shape[0], s1.shape[1] + add, s1.shape[2]])
981
1419
  for k in range(s1.shape[0]):
982
1420
  for i in range(s1.shape[2]):
983
- so1[k,:,i]=self.model(s1[k,:,i],dx=4,dell=1,add=add)
984
- so1[k,np.isnan(so1[k,:,i]),i]=0.0
1421
+ so1[k, :, i] = self.model(s1[k, :, i], dx=4, dell=1, add=add)
1422
+ so1[k, np.isnan(so1[k, :, i]), i] = 0.0
985
1423
  return so1
986
1424
 
987
- def extrapol(self,add):
988
- return scat(self.extrapol_s1(self.P00,add), \
989
- self.S0, \
990
- self.extrapol_s1(self.S1,add), \
991
- self.extrapol_s2(add,lnorm=1), \
992
- self.extrapol_s2(add,lnorm=2),self.j1,self.j2,backend=self.backend)
993
-
994
-
995
-
996
-
997
-
1425
+ def extrapol(self, add):
1426
+ return scat(
1427
+ self.extrapol_s1(self.P00, add),
1428
+ self.S0,
1429
+ self.extrapol_s1(self.S1, add),
1430
+ self.extrapol_s2(add, lnorm=1),
1431
+ self.extrapol_s2(add, lnorm=2),
1432
+ self.j1,
1433
+ self.j2,
1434
+ backend=self.backend,
1435
+ )
1436
+
1437
+
998
1438
  class funct(FOC.FoCUS):
999
-
1000
- def fill(self,im,nullval=hp.UNSEEN):
1001
- return self.fill_healpy(im,nullval=nullval)
1002
-
1003
- def moments(self,list_scat):
1004
- S0=None
1439
+
1440
+ def fill(self, im, nullval=hp.UNSEEN):
1441
+ return self.fill_healpy(im, nullval=nullval)
1442
+
1443
+ def moments(self, list_scat):
1444
+ S0 = None
1005
1445
  for k in list_scat:
1006
- tmp=list_scat[k]
1007
- nS0=np.expand_dims(tmp.S0.numpy(),0)
1008
- nP00=np.expand_dims(tmp.P00.numpy(),0)
1009
- nS1=np.expand_dims(tmp.S1.numpy(),0)
1010
- nS2=np.expand_dims(tmp.S2.numpy(),0)
1011
- nS2L=np.expand_dims(tmp.S2L.numpy(),0)
1012
-
1446
+ tmp = list_scat[k]
1447
+ nS0 = np.expand_dims(tmp.S0.numpy(), 0)
1448
+ nP00 = np.expand_dims(tmp.P00.numpy(), 0)
1449
+ nS1 = np.expand_dims(tmp.S1.numpy(), 0)
1450
+ nS2 = np.expand_dims(tmp.S2.numpy(), 0)
1451
+ nS2L = np.expand_dims(tmp.S2L.numpy(), 0)
1452
+
1013
1453
  if S0 is None:
1014
- S0=nS0
1015
- P00=nP00
1016
- S1=nS1
1017
- S2=nS2
1018
- S2L=nS2L
1454
+ S0 = nS0
1455
+ P00 = nP00
1456
+ S1 = nS1
1457
+ S2 = nS2
1458
+ S2L = nS2L
1019
1459
  else:
1020
- S0=np.concatenate([S0,nS0],0)
1021
- P00=np.concatenate([P00,nP00],0)
1022
- S1=np.concatenate([S1,nS1],0)
1023
- S2=np.concatenate([S2,nS2],0)
1024
- S2L=np.concatenate([S2L,nS2L],0)
1025
-
1026
- sS0=np.std(S0,0)
1027
- sP00=np.std(P00,0)
1028
- sS1=np.std(S1,0)
1029
- sS2=np.std(S2,0)
1030
- sS2L=np.std(S2L,0)
1031
-
1032
- mS0=np.mean(S0,0)
1033
- mP00=np.mean(P00,0)
1034
- mS1=np.mean(S1,0)
1035
- mS2=np.mean(S2,0)
1036
- mS2L=np.mean(S2L,0)
1037
-
1038
- return scat(mP00,mS0,mS1,mS2,mS2L,tmp.j1,tmp.j2,backend=self.backend), \
1039
- scat(sP00,sS0,sS1,sS2,sS2L,tmp.j1,tmp.j2,backend=self.backend)
1040
-
1041
- def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,calc_var=False,norm=None):
1460
+ S0 = np.concatenate([S0, nS0], 0)
1461
+ P00 = np.concatenate([P00, nP00], 0)
1462
+ S1 = np.concatenate([S1, nS1], 0)
1463
+ S2 = np.concatenate([S2, nS2], 0)
1464
+ S2L = np.concatenate([S2L, nS2L], 0)
1465
+
1466
+ sS0 = np.std(S0, 0)
1467
+ sP00 = np.std(P00, 0)
1468
+ sS1 = np.std(S1, 0)
1469
+ sS2 = np.std(S2, 0)
1470
+ sS2L = np.std(S2L, 0)
1471
+
1472
+ mS0 = np.mean(S0, 0)
1473
+ mP00 = np.mean(P00, 0)
1474
+ mS1 = np.mean(S1, 0)
1475
+ mS2 = np.mean(S2, 0)
1476
+ mS2L = np.mean(S2L, 0)
1477
+
1478
+ return scat(
1479
+ mP00, mS0, mS1, mS2, mS2L, tmp.j1, tmp.j2, backend=self.backend
1480
+ ), scat(sP00, sS0, sS1, sS2, sS2L, tmp.j1, tmp.j2, backend=self.backend)
1481
+
1482
+ def eval(
1483
+ self,
1484
+ image1,
1485
+ image2=None,
1486
+ mask=None,
1487
+ Auto=True,
1488
+ s0_off=1e-6,
1489
+ calc_var=False,
1490
+ norm=None,
1491
+ ):
1042
1492
  # Check input consistency
1043
1493
  if image2 is not None:
1044
- if list(image1.shape)!=list(image2.shape):
1045
- print('The two input image should have the same size to eval Scattering')
1046
-
1494
+ if list(image1.shape) != list(image2.shape):
1495
+ print(
1496
+ "The two input image should have the same size to eval Scattering"
1497
+ )
1498
+
1047
1499
  return None
1048
1500
  if mask is not None:
1049
- if list(image1.shape)!=list(mask.shape)[1:]:
1050
- print('The mask should have the same size than the input image to eval Scattering')
1051
- print('Image shape ',image1.shape,'Mask shape ',mask.shape)
1501
+ if list(image1.shape) != list(mask.shape)[1:]:
1502
+ print(
1503
+ "The mask should have the same size than the input image to eval Scattering"
1504
+ )
1505
+ print("Image shape ", image1.shape, "Mask shape ", mask.shape)
1052
1506
  return None
1053
- if self.use_2D and len(image1.shape)<2:
1054
- print('To work with 2D scattering transform, two dimension is needed, input map has only on dimension')
1507
+ if self.use_2D and len(image1.shape) < 2:
1508
+ print(
1509
+ "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
1510
+ )
1055
1511
  return None
1056
-
1057
-
1512
+
1058
1513
  ### AUTO OR CROSS
1059
1514
  cross = False
1060
1515
  if image2 is not None:
1061
1516
  cross = True
1062
- all_cross=not Auto
1063
- else:
1064
- all_cross=False
1065
-
1517
+
1066
1518
  # Check if image1 is [Npix] or [Nbatch,Npix]
1067
- axis=1
1068
-
1519
+ axis = 1
1520
+
1069
1521
  # determine jmax and nside corresponding to the input map
1070
1522
  im_shape = image1.shape
1071
1523
  if self.use_2D:
1072
- if len(image1.shape)==2:
1073
- nside=np.min([im_shape[0],im_shape[1]])
1074
- npix = im_shape[0]*im_shape[1] # Number of pixels
1075
- x1=im_shape[0]
1076
- x2=im_shape[1]
1524
+ if len(image1.shape) == 2:
1525
+ nside = np.min([im_shape[0], im_shape[1]])
1526
+ npix = im_shape[0] * im_shape[1] # Number of pixels
1077
1527
  else:
1078
- nside=np.min([im_shape[1],im_shape[2]])
1079
- npix = im_shape[1]*im_shape[2] # Number of pixels
1080
- x1=im_shape[1]
1081
- x2=im_shape[2]
1082
- jmax = int(np.log(nside-self.KERNELSZ) / np.log(2)) # Number of j scales
1528
+ nside = np.min([im_shape[1], im_shape[2]])
1529
+ npix = im_shape[1] * im_shape[2] # Number of pixels
1530
+ jmax = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
1083
1531
  else:
1084
- if len(image1.shape)==2:
1532
+ if len(image1.shape) == 2:
1085
1533
  npix = int(im_shape[1]) # Number of pixels
1086
1534
  else:
1087
1535
  npix = int(im_shape[0]) # Number of pixels
1088
1536
 
1089
- nside=int(np.sqrt(npix//12))
1090
-
1091
- jmax=int(np.log(nside)/np.log(2)) #-self.OSTEP
1537
+ nside = int(np.sqrt(npix // 12))
1538
+
1539
+ jmax = int(np.log(nside) / np.log(2)) # -self.OSTEP
1092
1540
 
1093
1541
  ### LOCAL VARIABLES (IMAGES and MASK)
1094
1542
  # Check if image1 is [Npix] or [Nbatch,Npix]
1095
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1543
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1096
1544
  # image1 is [Nbatch, Npix]
1097
- I1 = self.backend.bk_cast(self.backend.bk_expand_dims(image1,0)) # Local image1 [Nbatch, Npix]
1545
+ I1 = self.backend.bk_cast(
1546
+ self.backend.bk_expand_dims(image1, 0)
1547
+ ) # Local image1 [Nbatch, Npix]
1098
1548
  if cross:
1099
- I2 = self.backend.bk_cast(self.backend.bk_expand_dims(image2,0)) # Local image2 [Nbatch, Npix]
1549
+ I2 = self.backend.bk_cast(
1550
+ self.backend.bk_expand_dims(image2, 0)
1551
+ ) # Local image2 [Nbatch, Npix]
1100
1552
  else:
1101
- I1=self.backend.bk_cast(image1)
1553
+ I1 = self.backend.bk_cast(image1)
1102
1554
  if cross:
1103
- I2=self.backend.bk_cast(image2)
1104
-
1555
+ I2 = self.backend.bk_cast(image2)
1556
+
1105
1557
  # self.mask is [Nmask, Npix]
1106
1558
  if mask is None:
1107
1559
  if self.use_2D:
1108
- vmask = self.backend.bk_ones([1, I1.shape[axis], I1.shape[axis+1]],dtype=self.all_type)
1560
+ vmask = self.backend.bk_ones(
1561
+ [1, I1.shape[axis], I1.shape[axis + 1]], dtype=self.all_type
1562
+ )
1109
1563
  else:
1110
1564
  vmask = self.backend.bk_ones([1, I1.shape[axis]], dtype=self.all_type)
1111
1565
  else:
1112
1566
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
1113
1567
 
1114
- if self.KERNELSZ>3:
1115
- if self.KERNELSZ==5:
1568
+ if self.KERNELSZ > 3:
1569
+ if self.KERNELSZ == 5:
1116
1570
  # if the kernel size is bigger than 3 increase the binning before smoothing
1117
1571
  if self.use_2D:
1118
- l_image1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1119
- vmask=self.up_grade(vmask,I1.shape[axis]*2,axis=1,nouty=I1.shape[axis+1]*2)
1572
+ l_image1 = self.up_grade(
1573
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
1574
+ )
1575
+ vmask = self.up_grade(
1576
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
1577
+ )
1120
1578
  else:
1121
- l_image1=self.up_grade(I1,nside*2,axis=axis)
1122
- vmask=self.up_grade(vmask,nside*2,axis=1)
1123
-
1579
+ l_image1 = self.up_grade(I1, nside * 2, axis=axis)
1580
+ vmask = self.up_grade(vmask, nside * 2, axis=1)
1581
+
1124
1582
  if cross:
1125
1583
  if self.use_2D:
1126
- l_image2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1584
+ l_image2 = self.up_grade(
1585
+ I2,
1586
+ I2.shape[axis] * 2,
1587
+ axis=axis,
1588
+ nouty=I2.shape[axis + 1] * 2,
1589
+ )
1127
1590
  else:
1128
- l_image2=self.up_grade(I2,nside*2,axis=axis)
1591
+ l_image2 = self.up_grade(I2, nside * 2, axis=axis)
1129
1592
  else:
1130
1593
  # if the kernel size is bigger than 3 increase the binning before smoothing
1131
1594
  if self.use_2D:
1132
- l_image1=self.up_grade(l_image1,I1.shape[axis]*4,axis=axis,nouty=I1.shape[axis+1]*4)
1133
- vmask=self.up_grade(vmask,I1.shape[axis]*4,axis=1,nouty=I1.shape[axis+1]*4)
1595
+ l_image1 = self.up_grade(
1596
+ l_image1,
1597
+ I1.shape[axis] * 4,
1598
+ axis=axis,
1599
+ nouty=I1.shape[axis + 1] * 4,
1600
+ )
1601
+ vmask = self.up_grade(
1602
+ vmask, I1.shape[axis] * 4, axis=1, nouty=I1.shape[axis + 1] * 4
1603
+ )
1134
1604
  else:
1135
- l_image1=self.up_grade(l_image1,nside*4,axis=axis)
1136
- vmask=self.up_grade(vmask,nside*4,axis=1)
1137
-
1605
+ l_image1 = self.up_grade(l_image1, nside * 4, axis=axis)
1606
+ vmask = self.up_grade(vmask, nside * 4, axis=1)
1607
+
1138
1608
  if cross:
1139
1609
  if self.use_2D:
1140
- l_image2=self.up_grade(l_image2,I2.shape[axis]*4,axis=axis,nouty=I2.shape[axis+1]*4)
1610
+ l_image2 = self.up_grade(
1611
+ l_image2,
1612
+ I2.shape[axis] * 4,
1613
+ axis=axis,
1614
+ nouty=I2.shape[axis + 1] * 4,
1615
+ )
1141
1616
  else:
1142
- l_image2=self.up_grade(l_image2,nside*4,axis=axis)
1617
+ l_image2 = self.up_grade(l_image2, nside * 4, axis=axis)
1143
1618
  else:
1144
- l_image1=I1
1619
+ l_image1 = I1
1145
1620
  if cross:
1146
- l_image2=I2
1621
+ l_image2 = I2
1147
1622
 
1148
1623
  if calc_var:
1149
- s0,vs0 = self.masked_mean(l_image1,vmask,axis=axis,calc_var=True)
1150
- s0=s0+s0_off
1624
+ s0, vs0 = self.masked_mean(l_image1, vmask, axis=axis, calc_var=True)
1625
+ s0 = s0 + s0_off
1151
1626
  else:
1152
- s0 = self.masked_mean(l_image1,vmask,axis=axis)+s0_off
1153
-
1154
- if cross and Auto==False:
1627
+ s0 = self.masked_mean(l_image1, vmask, axis=axis) + s0_off
1628
+
1629
+ if cross and not Auto:
1155
1630
  if calc_var:
1156
- s02,vs02=self.masked_mean(l_image2,vmask,axis=axis,calc_var=True)
1631
+ s02, vs02 = self.masked_mean(l_image2, vmask, axis=axis, calc_var=True)
1157
1632
  else:
1158
- s02=self.masked_mean(l_image2,vmask,axis=axis)
1159
-
1160
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1633
+ s02 = self.masked_mean(l_image2, vmask, axis=axis)
1634
+
1635
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1161
1636
  if self.backend.bk_is_complex(s0):
1162
- s0 = self.backend.bk_complex(s0,s02+s0_off)
1637
+ s0 = self.backend.bk_complex(s0, s02 + s0_off)
1163
1638
  if calc_var:
1164
- vs0 = self.backend.bk_complex(vs0,vs02)
1639
+ vs0 = self.backend.bk_complex(vs0, vs02)
1165
1640
  else:
1166
- s0 = self.backend.bk_concat([s0,s02],axis=0)
1641
+ s0 = self.backend.bk_concat([s0, s02], axis=0)
1167
1642
  if calc_var:
1168
- vs0 = self.backend.bk_concat([vs0,vs02],axis=0)
1643
+ vs0 = self.backend.bk_concat([vs0, vs02], axis=0)
1169
1644
  else:
1170
1645
  if self.backend.bk_is_complex(s0):
1171
- s0 = self.backend.bk_complex(s0,s02+s0_off)
1646
+ s0 = self.backend.bk_complex(s0, s02 + s0_off)
1172
1647
  if calc_var:
1173
- vs0 = self.backend.bk_complex(vs0,vs02)
1648
+ vs0 = self.backend.bk_complex(vs0, vs02)
1174
1649
  else:
1175
- s0 = self.backend.bk_concat([s0,s02],axis=0)
1650
+ s0 = self.backend.bk_concat([s0, s02], axis=0)
1176
1651
  if calc_var:
1177
- vs0 = self.backend.bk_concat([vs0,vs02],axis=0)
1178
-
1179
- s1=None
1180
- s2=None
1181
- s2l=None
1182
- p00=None
1183
- s2j1=None
1184
- s2j2=None
1185
-
1186
- l2_image=None
1187
- l2_image_imag=None
1188
-
1652
+ vs0 = self.backend.bk_concat([vs0, vs02], axis=0)
1653
+
1654
+ s1 = None
1655
+ s2 = None
1656
+ s2l = None
1657
+ p00 = None
1658
+ s2j1 = None
1659
+ s2j2 = None
1660
+ l2_image = None
1189
1661
  for j1 in range(jmax):
1190
- if j1<jmax-self.OSTEP: # stop to add scales
1662
+ if j1 < jmax - self.OSTEP: # stop to add scales
1191
1663
  # Convol image along the axis defined by 'axis' using the wavelet defined at
1192
1664
  # the foscat initialisation
1193
- #c_image_real is [....,Npix_j1,....,Norient]
1194
- c_image1=self.convol(l_image1,axis=axis)
1665
+ # c_image_real is [....,Npix_j1,....,Norient]
1666
+ c_image1 = self.convol(l_image1, axis=axis)
1195
1667
  if cross:
1196
- c_image2=self.convol(l_image2,axis=axis)
1668
+ c_image2 = self.convol(l_image2, axis=axis)
1197
1669
  else:
1198
- c_image2=c_image1
1670
+ c_image2 = c_image1
1199
1671
 
1200
1672
  # Compute (a+ib)*(a+ib)* the last c_image column is the real and imaginary part
1201
- conj=c_image1*self.backend.bk_conjugate(c_image2)
1202
-
1673
+ conj = c_image1 * self.backend.bk_conjugate(c_image2)
1674
+
1203
1675
  if Auto:
1204
- conj=self.backend.bk_real(conj)
1676
+ conj = self.backend.bk_real(conj)
1205
1677
 
1206
1678
  # Compute l_p00 [....,....,Nmask,j1,Norient]
1207
1679
  if calc_var:
1208
- l_p00,l_vp00 = self.masked_mean(conj,vmask,axis=axis,rank=j1,calc_var=True)
1209
- l_p00 = self.backend.bk_expand_dims(l_p00,-2)
1210
- l_vp00 = self.backend.bk_expand_dims(l_vp00,-2)
1680
+ l_p00, l_vp00 = self.masked_mean(
1681
+ conj, vmask, axis=axis, rank=j1, calc_var=True
1682
+ )
1683
+ l_p00 = self.backend.bk_expand_dims(l_p00, -2)
1684
+ l_vp00 = self.backend.bk_expand_dims(l_vp00, -2)
1211
1685
  else:
1212
- l_p00 = self.masked_mean(conj,vmask,axis=axis,rank=j1)
1213
- l_p00 = self.backend.bk_expand_dims(l_p00,-2)
1686
+ l_p00 = self.masked_mean(conj, vmask, axis=axis, rank=j1)
1687
+ l_p00 = self.backend.bk_expand_dims(l_p00, -2)
1214
1688
 
1215
- conj=self.backend.bk_L1(conj)
1689
+ conj = self.backend.bk_L1(conj)
1216
1690
 
1217
- # Compute l_s1 [....,....,Nmask,1,Norient]
1691
+ # Compute l_s1 [....,....,Nmask,1,Norient]
1218
1692
  if calc_var:
1219
- l_s1,l_vs1 = self.masked_mean(conj,vmask,axis=axis,rank=j1,calc_var=True)
1220
- l_s1 =self.backend.bk_expand_dims(l_s1,-2)
1221
- l_vs1 =self.backend.bk_expand_dims(l_vs1,-2)
1693
+ l_s1, l_vs1 = self.masked_mean(
1694
+ conj, vmask, axis=axis, rank=j1, calc_var=True
1695
+ )
1696
+ l_s1 = self.backend.bk_expand_dims(l_s1, -2)
1697
+ l_vs1 = self.backend.bk_expand_dims(l_vs1, -2)
1222
1698
  else:
1223
- l_s1 = self.backend.bk_expand_dims(self.masked_mean(conj,vmask,axis=axis,rank=j1),-2)
1699
+ l_s1 = self.backend.bk_expand_dims(
1700
+ self.masked_mean(conj, vmask, axis=axis, rank=j1), -2
1701
+ )
1224
1702
 
1225
- # Concat S1,P00 [....,....,Nmask,j1,Norient]
1703
+ # Concat S1,P00 [....,....,Nmask,j1,Norient]
1226
1704
  if s1 is None:
1227
- s1=l_s1
1228
- p00=l_p00
1705
+ s1 = l_s1
1706
+ p00 = l_p00
1229
1707
  if calc_var:
1230
- vs1=l_vs1
1231
- vp00=l_vp00
1708
+ vs1 = l_vs1
1709
+ vp00 = l_vp00
1232
1710
  else:
1233
- s1=self.backend.bk_concat([s1,l_s1],axis=-2)
1234
- p00=self.backend.bk_concat([p00,l_p00],axis=-2)
1711
+ s1 = self.backend.bk_concat([s1, l_s1], axis=-2)
1712
+ p00 = self.backend.bk_concat([p00, l_p00], axis=-2)
1235
1713
  if calc_var:
1236
- vs1=self.backend.bk_concat([vs1,l_vs1],axis=-2)
1237
- vp00=self.backend.bk_concat([vp00,l_vp00],axis=-2)
1714
+ vs1 = self.backend.bk_concat([vs1, l_vs1], axis=-2)
1715
+ vp00 = self.backend.bk_concat([vp00, l_vp00], axis=-2)
1238
1716
 
1239
1717
  # Concat l2_image [....,j1,Npix_j1,,....,Norient]
1240
1718
  if l2_image is None:
1241
1719
  if self.use_2D:
1242
- l2_image=self.backend.bk_expand_dims(conj,axis=-4)
1720
+ l2_image = self.backend.bk_expand_dims(conj, axis=-4)
1243
1721
  else:
1244
- l2_image=self.backend.bk_expand_dims(conj,axis=-3)
1722
+ l2_image = self.backend.bk_expand_dims(conj, axis=-3)
1245
1723
  else:
1246
1724
  if self.use_2D:
1247
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=-4),l2_image],axis=-4)
1725
+ l2_image = self.backend.bk_concat(
1726
+ [self.backend.bk_expand_dims(conj, axis=-4), l2_image],
1727
+ axis=-4,
1728
+ )
1248
1729
  else:
1249
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=-3),l2_image],axis=-3)
1730
+ l2_image = self.backend.bk_concat(
1731
+ [self.backend.bk_expand_dims(conj, axis=-3), l2_image],
1732
+ axis=-3,
1733
+ )
1250
1734
 
1251
1735
  # Convol l2_image [....,Npix_j1,j1,....,Norient,Norient]
1252
- c2_image=self.convol(self.backend.bk_relu(l2_image),axis=axis+1)
1736
+ c2_image = self.convol(self.backend.bk_relu(l2_image), axis=axis + 1)
1253
1737
 
1254
- conj2p=c2_image*self.backend.bk_conjugate(c2_image)
1255
- conj2pl1=self.backend.bk_L1(conj2p)
1738
+ conj2p = c2_image * self.backend.bk_conjugate(c2_image)
1739
+ conj2pl1 = self.backend.bk_L1(conj2p)
1256
1740
 
1257
1741
  if Auto:
1258
- conj2p=self.backend.bk_real(conj2p)
1259
- conj2pl1=self.backend.bk_real(conj2pl1)
1742
+ conj2p = self.backend.bk_real(conj2p)
1743
+ conj2pl1 = self.backend.bk_real(conj2pl1)
1260
1744
 
1261
- c2_image=self.convol(self.backend.bk_relu(-l2_image),axis=axis+1)
1745
+ c2_image = self.convol(self.backend.bk_relu(-l2_image), axis=axis + 1)
1262
1746
 
1263
- conj2m=c2_image*self.backend.bk_conjugate(c2_image)
1264
- conj2ml1=self.backend.bk_L1(conj2m)
1747
+ conj2m = c2_image * self.backend.bk_conjugate(c2_image)
1748
+ conj2ml1 = self.backend.bk_L1(conj2m)
1265
1749
 
1266
1750
  if Auto:
1267
- conj2m=self.backend.bk_real(conj2m)
1268
- conj2ml1=self.backend.bk_real(conj2ml1)
1269
-
1751
+ conj2m = self.backend.bk_real(conj2m)
1752
+ conj2ml1 = self.backend.bk_real(conj2ml1)
1753
+
1270
1754
  # Convol l_s2 [....,....,Nmask,j1,Norient,Norient]
1271
1755
  if calc_var:
1272
- l_s2,l_vs2 = self.masked_mean(conj2p-conj2m,vmask,axis=axis+1,rank=j1,calc_var=True)
1273
- l_s2l1,l_vs2l1 = self.masked_mean(conj2pl1-conj2ml1,vmask,axis=axis+1,rank=j1,calc_var=True)
1756
+ l_s2, l_vs2 = self.masked_mean(
1757
+ conj2p - conj2m, vmask, axis=axis + 1, rank=j1, calc_var=True
1758
+ )
1759
+ l_s2l1, l_vs2l1 = self.masked_mean(
1760
+ conj2pl1 - conj2ml1, vmask, axis=axis + 1, rank=j1, calc_var=True
1761
+ )
1274
1762
  else:
1275
- l_s2 = self.masked_mean(conj2p-conj2m,vmask,axis=axis+1,rank=j1)
1276
- l_s2l1 = self.masked_mean(conj2pl1-conj2ml1,vmask,axis=axis+1,rank=j1)
1763
+ l_s2 = self.masked_mean(conj2p - conj2m, vmask, axis=axis + 1, rank=j1)
1764
+ l_s2l1 = self.masked_mean(
1765
+ conj2pl1 - conj2ml1, vmask, axis=axis + 1, rank=j1
1766
+ )
1277
1767
 
1278
1768
  # Concat l_s2 [....,....,Nmask,j1*(j1+1)/2,Norient,Norient]
1279
1769
  if s2 is None:
1280
- s2l=l_s2
1281
- s2=l_s2l1
1770
+ s2l = l_s2
1771
+ s2 = l_s2l1
1282
1772
  if calc_var:
1283
- vs2l=l_vs2
1284
- vs2=l_vs2l1
1285
-
1286
- s2j1=np.arange(l_s2.shape[axis+1],dtype='int')
1287
- s2j2=j1*np.ones(l_s2.shape[axis+1],dtype='int')
1773
+ vs2l = l_vs2
1774
+ vs2 = l_vs2l1
1775
+
1776
+ s2j1 = np.arange(l_s2.shape[axis + 1], dtype="int")
1777
+ s2j2 = j1 * np.ones(l_s2.shape[axis + 1], dtype="int")
1288
1778
  else:
1289
- s2=self.backend.bk_concat([s2,l_s2l1],axis=-3)
1290
- s2l=self.backend.bk_concat([s2l,l_s2],axis=-3)
1779
+ s2 = self.backend.bk_concat([s2, l_s2l1], axis=-3)
1780
+ s2l = self.backend.bk_concat([s2l, l_s2], axis=-3)
1291
1781
  if calc_var:
1292
- vs2=self.backend.bk_concat([vs2,l_vs2l1],axis=-3)
1293
- vs2l=self.backend.bk_concat([vs2l,l_vs2],axis=-3)
1294
-
1295
- s2j1=np.concatenate([s2j1,np.arange(l_s2.shape[axis+1],dtype='int')],0)
1296
- s2j2=np.concatenate([s2j2,j1*np.ones(l_s2.shape[axis+1],dtype='int')],0)
1297
-
1298
- if j1!=jmax-1:
1299
- # Rescale vmask [Nmask,Npix_j1//4]
1300
- vmask = self.smooth(vmask,axis=1)
1301
- vmask = self.ud_grade_2(vmask,axis=1)
1782
+ vs2 = self.backend.bk_concat([vs2, l_vs2l1], axis=-3)
1783
+ vs2l = self.backend.bk_concat([vs2l, l_vs2], axis=-3)
1784
+
1785
+ s2j1 = np.concatenate(
1786
+ [s2j1, np.arange(l_s2.shape[axis + 1], dtype="int")], 0
1787
+ )
1788
+ s2j2 = np.concatenate(
1789
+ [s2j2, j1 * np.ones(l_s2.shape[axis + 1], dtype="int")], 0
1790
+ )
1791
+
1792
+ if j1 != jmax - 1:
1793
+ # Rescale vmask [Nmask,Npix_j1//4]
1794
+ vmask = self.smooth(vmask, axis=1)
1795
+ vmask = self.ud_grade_2(vmask, axis=1)
1302
1796
  if self.mask_thres is not None:
1303
- vmask = self.backend.bk_threshold(vmask,self.mask_thres)
1797
+ vmask = self.backend.bk_threshold(vmask, self.mask_thres)
1304
1798
 
1305
- # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1306
- l2_image = self.smooth(l2_image,axis=axis+1)
1307
- l2_image = self.ud_grade_2(l2_image,axis=axis+1)
1799
+ # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1800
+ l2_image = self.smooth(l2_image, axis=axis + 1)
1801
+ l2_image = self.ud_grade_2(l2_image, axis=axis + 1)
1308
1802
 
1309
- # Rescale l_image [....,Npix_j1//4,....]
1310
- l_image1 = self.smooth(l_image1,axis=axis)
1311
- l_image1 = self.ud_grade_2(l_image1,axis=axis)
1803
+ # Rescale l_image [....,Npix_j1//4,....]
1804
+ l_image1 = self.smooth(l_image1, axis=axis)
1805
+ l_image1 = self.ud_grade_2(l_image1, axis=axis)
1312
1806
  if cross:
1313
- l_image2 = self.smooth(l_image2,axis=axis)
1314
- l_image2 = self.ud_grade_2(l_image2,axis=axis)
1315
-
1316
-
1317
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1318
- sc_ret=scat(p00[0],s0[0],s1[0],s2[0],s2l[0],s2j1,s2j2,cross=cross,backend=self.backend)
1807
+ l_image2 = self.smooth(l_image2, axis=axis)
1808
+ l_image2 = self.ud_grade_2(l_image2, axis=axis)
1809
+
1810
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1811
+ sc_ret = scat(
1812
+ p00[0],
1813
+ s0[0],
1814
+ s1[0],
1815
+ s2[0],
1816
+ s2l[0],
1817
+ s2j1,
1818
+ s2j2,
1819
+ cross=cross,
1820
+ backend=self.backend,
1821
+ )
1319
1822
  else:
1320
- sc_ret=scat(p00,s0,s1,s2,s2l,s2j1,s2j2,cross=cross,backend=self.backend)
1321
-
1823
+ sc_ret = scat(
1824
+ p00, s0, s1, s2, s2l, s2j1, s2j2, cross=cross, backend=self.backend
1825
+ )
1826
+
1322
1827
  if calc_var:
1323
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1324
- vsc_ret=scat(vp00[0],vs0[0],vs1[0],vs2[0],vs2l[0],s2j1,s2j2,cross=cross,backend=self.backend)
1828
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1829
+ vsc_ret = scat(
1830
+ vp00[0],
1831
+ vs0[0],
1832
+ vs1[0],
1833
+ vs2[0],
1834
+ vs2l[0],
1835
+ s2j1,
1836
+ s2j2,
1837
+ cross=cross,
1838
+ backend=self.backend,
1839
+ )
1325
1840
  else:
1326
- vsc_ret=scat(vp00,vs0,vs1,vs2,vs2l,s2j1,s2j2,cross=cross,backend=self.backend)
1327
- return sc_ret,vsc_ret
1841
+ vsc_ret = scat(
1842
+ vp00,
1843
+ vs0,
1844
+ vs1,
1845
+ vs2,
1846
+ vs2l,
1847
+ s2j1,
1848
+ s2j2,
1849
+ cross=cross,
1850
+ backend=self.backend,
1851
+ )
1852
+ return sc_ret, vsc_ret
1328
1853
  else:
1329
1854
  return sc_ret
1330
1855
 
1331
- def square(self,x):
1856
+ def square(self, x):
1332
1857
  # the abs make the complex value usable for reduce_sum or mean
1333
- return scat(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1334
- self.backend.bk_square(self.backend.bk_abs(x.S0)),
1335
- self.backend.bk_square(self.backend.bk_abs(x.S1)),
1336
- self.backend.bk_square(self.backend.bk_abs(x.S2)),
1337
- self.backend.bk_square(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1338
-
1339
- def sqrt(self,x):
1858
+ return scat(
1859
+ self.backend.bk_square(self.backend.bk_abs(x.P00)),
1860
+ self.backend.bk_square(self.backend.bk_abs(x.S0)),
1861
+ self.backend.bk_square(self.backend.bk_abs(x.S1)),
1862
+ self.backend.bk_square(self.backend.bk_abs(x.S2)),
1863
+ self.backend.bk_square(self.backend.bk_abs(x.S2L)),
1864
+ x.j1,
1865
+ x.j2,
1866
+ backend=self.backend,
1867
+ )
1868
+
1869
+ def sqrt(self, x):
1340
1870
  # the abs make the complex value usable for reduce_sum or mean
1341
- return scat(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1342
- self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1343
- self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1344
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1345
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1346
-
1347
- def reduce_distance(self, x,y, sigma=None):
1348
-
1871
+ return scat(
1872
+ self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1873
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1874
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1875
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1876
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),
1877
+ x.j1,
1878
+ x.j2,
1879
+ backend=self.backend,
1880
+ )
1881
+
1882
+ def reduce_distance(self, x, y, sigma=None):
1883
+
1349
1884
  if isinstance(x, scat):
1350
1885
  if sigma is None:
1351
- result=self.diff_data(y.S0,x.S0,is_complex=False)
1352
- result+=self.diff_data(y.S1,x.S1)
1353
- result+=self.diff_data(y.P00,x.P00)
1354
- result+=self.diff_data(y.S2,x.S2)
1355
- result+=self.diff_data(y.S2L,x.S2L)
1886
+ result = self.diff_data(y.S0, x.S0, is_complex=False)
1887
+ result += self.diff_data(y.S1, x.S1)
1888
+ result += self.diff_data(y.P00, x.P00)
1889
+ result += self.diff_data(y.S2, x.S2)
1890
+ result += self.diff_data(y.S2L, x.S2L)
1356
1891
  else:
1357
- result=self.diff_data(y.S0,x.S0,is_complex=False,sigma=sigma.S0)
1358
- result+=self.diff_data(y.S1,x.S1,sigma=sigma.S1)
1359
- result+=self.diff_data(y.P00,x.P00,sigma=sigma.P00)
1360
- result+=self.diff_data(y.S2,x.S2,sigma=sigma.S2)
1361
- result+=self.diff_data(y.S2L,x.S2L,sigma=sigma.S2L)
1362
-
1363
- nval=self.backend.bk_size(x.S0)+self.backend.bk_size(x.P00)+ \
1364
- self.backend.bk_size(x.S1)+self.backend.bk_size(x.S2)+self.backend.bk_size(x.S2L)
1365
-
1366
- result/=self.backend.bk_cast(nval)
1892
+ result = self.diff_data(y.S0, x.S0, is_complex=False, sigma=sigma.S0)
1893
+ result += self.diff_data(y.S1, x.S1, sigma=sigma.S1)
1894
+ result += self.diff_data(y.P00, x.P00, sigma=sigma.P00)
1895
+ result += self.diff_data(y.S2, x.S2, sigma=sigma.S2)
1896
+ result += self.diff_data(y.S2L, x.S2L, sigma=sigma.S2L)
1897
+
1898
+ nval = (
1899
+ self.backend.bk_size(x.S0)
1900
+ + self.backend.bk_size(x.P00)
1901
+ + self.backend.bk_size(x.S1)
1902
+ + self.backend.bk_size(x.S2)
1903
+ + self.backend.bk_size(x.S2L)
1904
+ )
1905
+
1906
+ result /= self.backend.bk_cast(nval)
1367
1907
  else:
1368
1908
  return self.backend.bk_reduce_sum(x)
1369
1909
  return result
1370
1910
 
1371
- def reduce_mean(self,x,axis=None):
1911
+ def reduce_mean(self, x, axis=None):
1372
1912
  if axis is None:
1373
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))+ \
1374
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))+ \
1375
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))+ \
1376
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))+ \
1377
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1378
-
1379
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1380
- np.array(list(x.S0.shape)).prod()+ \
1381
- np.array(list(x.S1.shape)).prod()+ \
1382
- np.array(list(x.S2.shape)).prod()
1383
-
1384
- return tmp/ntmp
1913
+ tmp = (
1914
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))
1915
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))
1916
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))
1917
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))
1918
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1919
+ )
1920
+
1921
+ ntmp = (
1922
+ np.array(list(x.P00.shape)).prod()
1923
+ + np.array(list(x.S0.shape)).prod()
1924
+ + np.array(list(x.S1.shape)).prod()
1925
+ + np.array(list(x.S2.shape)).prod()
1926
+ )
1927
+
1928
+ return tmp / ntmp
1385
1929
  else:
1386
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00,axis=axis))+ \
1387
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0,axis=axis))+ \
1388
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1,axis=axis))+ \
1389
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2,axis=axis))+ \
1390
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L,axis=axis))
1391
-
1392
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1393
- np.array(list(x.S0.shape)).prod()+ \
1394
- np.array(list(x.S1.shape)).prod()+ \
1395
- np.array(list(x.S2.shape)).prod()+ \
1396
- np.array(list(x.S2L.shape)).prod()
1397
-
1398
- return tmp/ntmp
1399
-
1400
- def reduce_sum(self,x,axis=None):
1930
+ tmp = (
1931
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00, axis=axis))
1932
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0, axis=axis))
1933
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1, axis=axis))
1934
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2, axis=axis))
1935
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L, axis=axis))
1936
+ )
1937
+
1938
+ ntmp = (
1939
+ np.array(list(x.P00.shape)).prod()
1940
+ + np.array(list(x.S0.shape)).prod()
1941
+ + np.array(list(x.S1.shape)).prod()
1942
+ + np.array(list(x.S2.shape)).prod()
1943
+ + np.array(list(x.S2L.shape)).prod()
1944
+ )
1945
+
1946
+ return tmp / ntmp
1947
+
1948
+ def reduce_sum(self, x, axis=None):
1401
1949
  if axis is None:
1402
- return self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))+ \
1403
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))+ \
1404
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))+ \
1405
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))+ \
1406
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1950
+ return (
1951
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))
1952
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
1953
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
1954
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
1955
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1956
+ )
1407
1957
  else:
1408
- return scat(self.backend.bk_reduce_sum(x.P00,axis=axis),
1409
- self.backend.bk_reduce_sum(x.S0,axis=axis),
1410
- self.backend.bk_reduce_sum(x.S1,axis=axis),
1411
- self.backend.bk_reduce_sum(x.S2,axis=axis),
1412
- self.backend.bk_reduce_sum(x.S2L,axis=axis),x.j1,x.j2,backend=self.backend)
1413
-
1414
- def ldiff(self,sig,x):
1415
- return scat(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1416
- x.domult(sig.S0,x.S0)*x.domult(sig.S0,x.S0),
1417
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1418
- x.domult(sig.S2,x.S2)*x.domult(sig.S2,x.S2),
1419
- x.domult(sig.S2L,x.S2L)*x.domult(sig.S2L,x.S2L),x.j1,x.j2,backend=self.backend)
1420
-
1421
- def log(self,x):
1422
- return scat(self.backend.bk_log(x.P00),
1423
- self.backend.bk_log(x.S0),
1424
- self.backend.bk_log(x.S1),
1425
- self.backend.bk_log(x.S2),
1426
- self.backend.bk_log(x.S2L),x.j1,x.j2,backend=self.backend)
1427
- def abs(self,x):
1428
- return scat(self.backend.bk_abs(x.P00),
1429
- self.backend.bk_abs(x.S0),
1430
- self.backend.bk_abs(x.S1),
1431
- self.backend.bk_abs(x.S2),
1432
- self.backend.bk_abs(x.S2L),x.j1,x.j2,backend=self.backend)
1433
- def inv(self,x):
1434
- return scat(1/(x.P00),1/(x.S0),1/(x.S1),1/(x.S2),1/(x.S2L),x.j1,x.j2,backend=self.backend)
1958
+ return scat(
1959
+ self.backend.bk_reduce_sum(x.P00, axis=axis),
1960
+ self.backend.bk_reduce_sum(x.S0, axis=axis),
1961
+ self.backend.bk_reduce_sum(x.S1, axis=axis),
1962
+ self.backend.bk_reduce_sum(x.S2, axis=axis),
1963
+ self.backend.bk_reduce_sum(x.S2L, axis=axis),
1964
+ x.j1,
1965
+ x.j2,
1966
+ backend=self.backend,
1967
+ )
1968
+
1969
+ def ldiff(self, sig, x):
1970
+ return scat(
1971
+ x.domult(sig.P00, x.P00) * x.domult(sig.P00, x.P00),
1972
+ x.domult(sig.S0, x.S0) * x.domult(sig.S0, x.S0),
1973
+ x.domult(sig.S1, x.S1) * x.domult(sig.S1, x.S1),
1974
+ x.domult(sig.S2, x.S2) * x.domult(sig.S2, x.S2),
1975
+ x.domult(sig.S2L, x.S2L) * x.domult(sig.S2L, x.S2L),
1976
+ x.j1,
1977
+ x.j2,
1978
+ backend=self.backend,
1979
+ )
1980
+
1981
+ def log(self, x):
1982
+ return scat(
1983
+ self.backend.bk_log(x.P00),
1984
+ self.backend.bk_log(x.S0),
1985
+ self.backend.bk_log(x.S1),
1986
+ self.backend.bk_log(x.S2),
1987
+ self.backend.bk_log(x.S2L),
1988
+ x.j1,
1989
+ x.j2,
1990
+ backend=self.backend,
1991
+ )
1992
+
1993
+ def abs(self, x):
1994
+ return scat(
1995
+ self.backend.bk_abs(x.P00),
1996
+ self.backend.bk_abs(x.S0),
1997
+ self.backend.bk_abs(x.S1),
1998
+ self.backend.bk_abs(x.S2),
1999
+ self.backend.bk_abs(x.S2L),
2000
+ x.j1,
2001
+ x.j2,
2002
+ backend=self.backend,
2003
+ )
2004
+
2005
+ def inv(self, x):
2006
+ return scat(
2007
+ 1 / (x.P00),
2008
+ 1 / (x.S0),
2009
+ 1 / (x.S1),
2010
+ 1 / (x.S2),
2011
+ 1 / (x.S2L),
2012
+ x.j1,
2013
+ x.j2,
2014
+ backend=self.backend,
2015
+ )
1435
2016
 
1436
2017
  def one(self):
1437
- return scat(1.0,1.0,1.0,1.0,1.0,[0],[0],backend=self.backend)
2018
+ return scat(1.0, 1.0, 1.0, 1.0, 1.0, [0], [0], backend=self.backend)
1438
2019
 
1439
2020
  @tf_function
1440
- def eval_comp_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
2021
+ def eval_comp_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1441
2022
 
1442
- res=self.eval(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1443
- return res.P00,res.S0,res.S1,res.S2,res.S2L,res.j1,res.j2
2023
+ res = self.eval(image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off)
2024
+ return res.P00, res.S0, res.S1, res.S2, res.S2L, res.j1, res.j2
1444
2025
 
1445
- def eval_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1446
- p0,s0,s1,s2,s2l,j1,j2=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1447
- return scat(p0,s0,s1,s2,s2l,j1,j2,backend=self.backend)
1448
-
1449
-
1450
-
2026
+ def eval_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
2027
+ p0, s0, s1, s2, s2l, j1, j2 = self.eval_comp_fast(
2028
+ image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off
2029
+ )
2030
+ return scat(p0, s0, s1, s2, s2l, j1, j2, backend=self.backend)