foscat 3.1.5__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat.py CHANGED
@@ -1,433 +1,602 @@
1
- import foscat.FoCUS as FOC
2
- import numpy as np
3
1
  import pickle
4
- import foscat.backend as bk
5
- import healpy as hp
6
2
  import sys
7
-
3
+
4
+ import healpy as hp
5
+ import numpy as np
6
+
7
+ import foscat.backend as bk
8
+ import foscat.FoCUS as FOC
9
+
8
10
  # Vérifier si TensorFlow est importé et défini
9
- tf_defined = 'tensorflow' in sys.modules
11
+ tf_defined = "tensorflow" in sys.modules
10
12
 
11
13
  if tf_defined:
12
14
  import tensorflow as tf
13
- tf_function = tf.function # Facultatif : si vous voulez utiliser TensorFlow dans ce script
15
+
16
+ tf_function = (
17
+ tf.function
18
+ ) # Facultatif : si vous voulez utiliser TensorFlow dans ce script
14
19
  else:
20
+
15
21
  def tf_function(func):
16
22
  return func
17
-
23
+
24
+
18
25
  def read(filename):
19
- thescat=scat(1,1,1,1,1,[0],[0])
26
+ thescat = scat(1, 1, 1, 1, 1, [0], [0])
20
27
  return thescat.read(filename)
21
-
28
+
29
+
22
30
  class scat:
23
- def __init__(self,p00,s0,s1,s2,s2l,j1,j2,cross=False,backend=None):
24
- self.bk_type='SCAT'
25
- self.P00=p00
26
- self.S0=s0
27
- self.S1=s1
28
- self.S2=s2
29
- self.S2L=s2l
30
- self.j1=j1
31
- self.j2=j2
32
- self.cross=cross
33
- self.backend=backend
34
-
35
- def set_bk_type(self,bk_type):
36
- self.bk_type=bk_type
37
-
31
+ def __init__(self, p00, s0, s1, s2, s2l, j1, j2, cross=False, backend=None):
32
+ self.bk_type = "SCAT"
33
+ self.P00 = p00
34
+ self.S0 = s0
35
+ self.S1 = s1
36
+ self.S2 = s2
37
+ self.S2L = s2l
38
+ self.j1 = j1
39
+ self.j2 = j2
40
+ self.cross = cross
41
+ self.backend = backend
42
+
43
+ def set_bk_type(self, bk_type):
44
+ self.bk_type = bk_type
45
+
38
46
  def get_j_idx(self):
39
- return self.j1,self.j2
40
-
47
+ return self.j1, self.j2
48
+
41
49
  def get_S0(self):
42
- return(self.S0)
50
+ return self.S0
43
51
 
44
52
  def get_S1(self):
45
- return(self.S1)
46
-
53
+ return self.S1
54
+
47
55
  def get_S2(self):
48
- return(self.S2)
56
+ return self.S2
49
57
 
50
58
  def get_S2L(self):
51
- return(self.S2L)
59
+ return self.S2L
52
60
 
53
61
  def get_P00(self):
54
- return(self.P00)
62
+ return self.P00
55
63
 
56
64
  def reset_P00(self):
57
- self.P00=0*self.P00
65
+ self.P00 = 0 * self.P00
58
66
 
59
67
  def constant(self):
60
- return scat(self.backend.constant(self.P00 ), \
61
- self.backend.constant(self.S0 ), \
62
- self.backend.constant(self.S1 ), \
63
- self.backend.constant(self.S2 ), \
64
- self.backend.constant(self.S2L ), \
65
- self.j1 , \
66
- self.j2 ,backend=self.backend)
67
-
68
-
69
- def domult(self,x,y):
68
+ return scat(
69
+ self.backend.constant(self.P00),
70
+ self.backend.constant(self.S0),
71
+ self.backend.constant(self.S1),
72
+ self.backend.constant(self.S2),
73
+ self.backend.constant(self.S2L),
74
+ self.j1,
75
+ self.j2,
76
+ backend=self.backend,
77
+ )
78
+
79
+ def domult(self, x, y):
70
80
  try:
71
- return x*y
81
+ return x * y
72
82
  except:
73
- if x.dtype==y.dtype:
74
- return x*y
83
+ if x.dtype == y.dtype:
84
+ return x * y
75
85
  if self.backend.bk_is_complex(x):
76
86
 
77
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
87
+ return self.backend.bk_complex(
88
+ self.backend.bk_real(x) * y, self.backend.bk_imag(x) * y
89
+ )
78
90
  else:
79
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
91
+ return self.backend.bk_complex(
92
+ self.backend.bk_real(y) * x, self.backend.bk_imag(y) * x
93
+ )
80
94
 
81
- def dodiv(self,x,y):
95
+ def dodiv(self, x, y):
82
96
  try:
83
- return x/y
97
+ return x / y
84
98
  except:
85
- if x.dtype==y.dtype:
86
- return x/y
99
+ if x.dtype == y.dtype:
100
+ return x / y
87
101
  if self.backend.bk_is_complex(x):
88
-
89
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
102
+
103
+ return self.backend.bk_complex(
104
+ self.backend.bk_real(x) / y, self.backend.bk_imag(x) / y
105
+ )
90
106
  else:
91
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
92
-
93
- def domin(self,x,y):
107
+ return self.backend.bk_complex(
108
+ x / self.backend.bk_real(y), x / self.backend.bk_imag(y)
109
+ )
110
+
111
+ def domin(self, x, y):
94
112
  try:
95
- return x-y
113
+ return x - y
96
114
  except:
97
- if x.dtype==y.dtype:
98
- return x-y
115
+ if x.dtype == y.dtype:
116
+ return x - y
99
117
 
100
118
  if self.backend.bk_is_complex(x):
101
119
 
102
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
120
+ return self.backend.bk_complex(
121
+ self.backend.bk_real(x) - y, self.backend.bk_imag(x) - y
122
+ )
103
123
  else:
104
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
105
-
106
- def doadd(self,x,y):
124
+ return self.backend.bk_complex(
125
+ x - self.backend.bk_real(y), x - self.backend.bk_imag(y)
126
+ )
127
+
128
+ def doadd(self, x, y):
107
129
  try:
108
- return x+y
130
+ return x + y
109
131
  except:
110
- if x.dtype==y.dtype:
111
- return x+y
132
+ if x.dtype == y.dtype:
133
+ return x + y
112
134
  if self.backend.bk_is_complex(x):
113
135
 
114
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
136
+ return self.backend.bk_complex(
137
+ self.backend.bk_real(x) + y, self.backend.bk_imag(x) + y
138
+ )
115
139
  else:
116
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
117
-
118
- def relu(self):
119
-
120
- return scat(self.backend.bk_relu(self.P00), \
121
- self.backend.bk_relu(self.S0), \
122
- self.backend.bk_relu(self.S1), \
123
- self.backend.bk_relu(self.S2), \
124
- self.backend.bk_relu(self.S2L), \
125
- self.j1,self.j2,backend=self.backend)
126
-
127
- def __add__(self,other):
128
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
129
- isinstance(other, bool) or isinstance(other, scat)
130
-
140
+ return self.backend.bk_complex(
141
+ x + self.backend.bk_real(y), x + self.backend.bk_imag(y)
142
+ )
143
+
144
+ def __add__(self, other):
145
+ assert (
146
+ isinstance(other, float)
147
+ or isinstance(other, np.float32)
148
+ or isinstance(other, int)
149
+ or isinstance(other, bool)
150
+ or isinstance(other, scat)
151
+ )
152
+
131
153
  if isinstance(other, scat):
132
- return scat(self.doadd(self.P00,other.P00), \
133
- self.doadd(self.S0, other.S0), \
134
- self.doadd(self.S1, other.S1), \
135
- self.doadd(self.S2, other.S2), \
136
- self.doadd(self.S2L, other.S2L), \
137
- self.j1,self.j2,backend=self.backend)
154
+ return scat(
155
+ self.doadd(self.P00, other.P00),
156
+ self.doadd(self.S0, other.S0),
157
+ self.doadd(self.S1, other.S1),
158
+ self.doadd(self.S2, other.S2),
159
+ self.doadd(self.S2L, other.S2L),
160
+ self.j1,
161
+ self.j2,
162
+ backend=self.backend,
163
+ )
138
164
  else:
139
- return scat((self.P00+ other), \
140
- (self.S0+ other), \
141
- (self.S1+ other), \
142
- (self.S2+ other), \
143
- (self.S2L+ other), \
144
- self.j1,self.j2,backend=self.backend)
145
-
146
- def toreal(self,value):
165
+ return scat(
166
+ (self.P00 + other),
167
+ (self.S0 + other),
168
+ (self.S1 + other),
169
+ (self.S2 + other),
170
+ (self.S2L + other),
171
+ self.j1,
172
+ self.j2,
173
+ backend=self.backend,
174
+ )
175
+
176
+ def toreal(self, value):
147
177
  if value is None:
148
178
  return None
149
-
179
+
150
180
  return self.backend.bk_real(value)
151
181
 
152
- def addcomplex(self,value,amp):
182
+ def addcomplex(self, value, amp):
153
183
  if value is None:
154
184
  return None
155
-
156
- return self.backend.bk_complex(value,amp*value)
157
-
158
- def add_complex(self,amp):
159
- return scat(self.addcomplex(self.P00,amp), \
160
- self.addcomplex(self.S0,amp), \
161
- self.addcomplex(self.S1,amp), \
162
- self.addcomplex(self.S2,amp), \
163
- self.addcomplex(self.S2L,amp), \
164
- self.j1,self.j2,backend=self.backend)
165
-
185
+
186
+ return self.backend.bk_complex(value, amp * value)
187
+
188
+ def add_complex(self, amp):
189
+ return scat(
190
+ self.addcomplex(self.P00, amp),
191
+ self.addcomplex(self.S0, amp),
192
+ self.addcomplex(self.S1, amp),
193
+ self.addcomplex(self.S2, amp),
194
+ self.addcomplex(self.S2L, amp),
195
+ self.j1,
196
+ self.j2,
197
+ backend=self.backend,
198
+ )
199
+
166
200
  def real(self):
167
- return scat(self.toreal(self.P00), \
168
- self.toreal(self.S0), \
169
- self.toreal(self.S1), \
170
- self.toreal(self.S2), \
171
- self.toreal(self.S2L), \
172
- self.j1,self.j2,backend=self.backend)
173
-
174
- def __radd__(self,other):
201
+ return scat(
202
+ self.toreal(self.P00),
203
+ self.toreal(self.S0),
204
+ self.toreal(self.S1),
205
+ self.toreal(self.S2),
206
+ self.toreal(self.S2L),
207
+ self.j1,
208
+ self.j2,
209
+ backend=self.backend,
210
+ )
211
+
212
+ def __radd__(self, other):
175
213
  return self.__add__(other)
176
214
 
177
- def __truediv__(self,other):
178
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
179
- isinstance(other, bool) or isinstance(other, scat)
180
-
215
+ def __truediv__(self, other):
216
+ assert (
217
+ isinstance(other, float)
218
+ or isinstance(other, np.float32)
219
+ or isinstance(other, int)
220
+ or isinstance(other, bool)
221
+ or isinstance(other, scat)
222
+ )
223
+
181
224
  if isinstance(other, scat):
182
- return scat(self.dodiv(self.P00, other.P00), \
183
- self.dodiv(self.S0, other.S0), \
184
- self.dodiv(self.S1, other.S1), \
185
- self.dodiv(self.S2, other.S2), \
186
- self.dodiv(self.S2L, other.S2L), \
187
- self.j1,self.j2,backend=self.backend)
225
+ return scat(
226
+ self.dodiv(self.P00, other.P00),
227
+ self.dodiv(self.S0, other.S0),
228
+ self.dodiv(self.S1, other.S1),
229
+ self.dodiv(self.S2, other.S2),
230
+ self.dodiv(self.S2L, other.S2L),
231
+ self.j1,
232
+ self.j2,
233
+ backend=self.backend,
234
+ )
188
235
  else:
189
- return scat((self.P00/ other), \
190
- (self.S0/ other), \
191
- (self.S1/ other), \
192
- (self.S2/ other), \
193
- (self.S2L/ other), \
194
- self.j1,self.j2,backend=self.backend)
195
-
196
-
197
- def __rtruediv__(self,other):
198
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
199
- isinstance(other, bool) or isinstance(other, scat)
200
-
236
+ return scat(
237
+ (self.P00 / other),
238
+ (self.S0 / other),
239
+ (self.S1 / other),
240
+ (self.S2 / other),
241
+ (self.S2L / other),
242
+ self.j1,
243
+ self.j2,
244
+ backend=self.backend,
245
+ )
246
+
247
+ def __rtruediv__(self, other):
248
+ assert (
249
+ isinstance(other, float)
250
+ or isinstance(other, np.float32)
251
+ or isinstance(other, int)
252
+ or isinstance(other, bool)
253
+ or isinstance(other, scat)
254
+ )
255
+
201
256
  if isinstance(other, scat):
202
- return scat(self.dodiv(other.P00, self.P00), \
203
- self.dodiv(other.S0 , self.S0), \
204
- self.dodiv(other.S1 , self.S1), \
205
- self.dodiv(other.S2 , self.S2), \
206
- self.dodiv(other.S2L , self.S2L), \
207
- self.j1,self.j2,backend=self.backend)
257
+ return scat(
258
+ self.dodiv(other.P00, self.P00),
259
+ self.dodiv(other.S0, self.S0),
260
+ self.dodiv(other.S1, self.S1),
261
+ self.dodiv(other.S2, self.S2),
262
+ self.dodiv(other.S2L, self.S2L),
263
+ self.j1,
264
+ self.j2,
265
+ backend=self.backend,
266
+ )
208
267
  else:
209
- return scat((other/ self.P00), \
210
- (other / self.S0), \
211
- (other / self.S1), \
212
- (other / self.S2), \
213
- (other / self.S2L), \
214
- self.j1,self.j2,backend=self.backend)
215
-
216
- def __sub__(self,other):
217
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
218
- isinstance(other, bool) or isinstance(other, scat)
219
-
268
+ return scat(
269
+ (other / self.P00),
270
+ (other / self.S0),
271
+ (other / self.S1),
272
+ (other / self.S2),
273
+ (other / self.S2L),
274
+ self.j1,
275
+ self.j2,
276
+ backend=self.backend,
277
+ )
278
+
279
+ def __sub__(self, other):
280
+ assert (
281
+ isinstance(other, float)
282
+ or isinstance(other, np.float32)
283
+ or isinstance(other, int)
284
+ or isinstance(other, bool)
285
+ or isinstance(other, scat)
286
+ )
287
+
220
288
  if isinstance(other, scat):
221
- return scat(self.domin(self.P00, other.P00), \
222
- self.domin(self.S0, other.S0), \
223
- self.domin(self.S1, other.S1), \
224
- self.domin(self.S2, other.S2), \
225
- self.domin(self.S2L, other.S2L), \
226
- self.j1,self.j2,backend=self.backend)
289
+ return scat(
290
+ self.domin(self.P00, other.P00),
291
+ self.domin(self.S0, other.S0),
292
+ self.domin(self.S1, other.S1),
293
+ self.domin(self.S2, other.S2),
294
+ self.domin(self.S2L, other.S2L),
295
+ self.j1,
296
+ self.j2,
297
+ backend=self.backend,
298
+ )
227
299
  else:
228
- return scat((self.P00- other), \
229
- (self.S0- other), \
230
- (self.S1- other), \
231
- (self.S2- other), \
232
- (self.S2L- other), \
233
- self.j1,self.j2,backend=self.backend)
234
-
235
- def __rsub__(self,other):
236
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
237
- isinstance(other, bool) or isinstance(other, scat)
238
-
300
+ return scat(
301
+ (self.P00 - other),
302
+ (self.S0 - other),
303
+ (self.S1 - other),
304
+ (self.S2 - other),
305
+ (self.S2L - other),
306
+ self.j1,
307
+ self.j2,
308
+ backend=self.backend,
309
+ )
310
+
311
+ def __rsub__(self, other):
312
+ assert (
313
+ isinstance(other, float)
314
+ or isinstance(other, np.float32)
315
+ or isinstance(other, int)
316
+ or isinstance(other, bool)
317
+ or isinstance(other, scat)
318
+ )
319
+
239
320
  if isinstance(other, scat):
240
- return scat(self.domin(other.P00,self.P00), \
241
- self.domin(other.S0, self.S0), \
242
- self.domin(other.S1, self.S1), \
243
- self.domin(other.S2, self.S2), \
244
- self.domin(other.S2L, self.S2L), \
245
- self.j1,self.j2,backend=self.backend)
321
+ return scat(
322
+ self.domin(other.P00, self.P00),
323
+ self.domin(other.S0, self.S0),
324
+ self.domin(other.S1, self.S1),
325
+ self.domin(other.S2, self.S2),
326
+ self.domin(other.S2L, self.S2L),
327
+ self.j1,
328
+ self.j2,
329
+ backend=self.backend,
330
+ )
246
331
  else:
247
- return scat((other-self.P00), \
248
- (other-self.S0), \
249
- (other-self.S1), \
250
- (other-self.S2), \
251
- (other-self.S2L), \
252
- self.j1,self.j2,backend=self.backend)
253
-
254
- def __mul__(self,other):
255
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
256
- isinstance(other, bool) or isinstance(other, scat)
257
-
332
+ return scat(
333
+ (other - self.P00),
334
+ (other - self.S0),
335
+ (other - self.S1),
336
+ (other - self.S2),
337
+ (other - self.S2L),
338
+ self.j1,
339
+ self.j2,
340
+ backend=self.backend,
341
+ )
342
+
343
+ def __mul__(self, other):
344
+ assert (
345
+ isinstance(other, float)
346
+ or isinstance(other, np.float32)
347
+ or isinstance(other, int)
348
+ or isinstance(other, bool)
349
+ or isinstance(other, scat)
350
+ )
351
+
258
352
  if isinstance(other, scat):
259
- return scat(self.domult(self.P00, other.P00), \
260
- self.domult(self.S0, other.S0), \
261
- self.domult(self.S1, other.S1), \
262
- self.domult(self.S2, other.S2), \
263
- self.domult(self.S2L, other.S2L), \
264
- self.j1,self.j2,backend=self.backend)
353
+ return scat(
354
+ self.domult(self.P00, other.P00),
355
+ self.domult(self.S0, other.S0),
356
+ self.domult(self.S1, other.S1),
357
+ self.domult(self.S2, other.S2),
358
+ self.domult(self.S2L, other.S2L),
359
+ self.j1,
360
+ self.j2,
361
+ backend=self.backend,
362
+ )
265
363
  else:
266
- return scat((self.P00* other), \
267
- (self.S0* other), \
268
- (self.S1* other), \
269
- (self.S2* other), \
270
- (self.S2L* other), \
271
- self.j1,self.j2,backend=self.backend)
364
+ return scat(
365
+ (self.P00 * other),
366
+ (self.S0 * other),
367
+ (self.S1 * other),
368
+ (self.S2 * other),
369
+ (self.S2L * other),
370
+ self.j1,
371
+ self.j2,
372
+ backend=self.backend,
373
+ )
374
+
272
375
  def relu(self):
273
- return scat(self.backend.bk_relu(self.P00),
274
- self.backend.bk_relu(self.S0),
275
- self.backend.bk_relu(self.S1),
276
- self.backend.bk_relu(self.S2),
277
- self.backend.bk_relu(self.S2L), \
278
- self.j1,self.j2,backend=self.backend)
279
-
280
-
281
- def __rmul__(self,other):
282
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
283
- isinstance(other, bool) or isinstance(other, scat)
284
-
376
+ return scat(
377
+ self.backend.bk_relu(self.P00),
378
+ self.backend.bk_relu(self.S0),
379
+ self.backend.bk_relu(self.S1),
380
+ self.backend.bk_relu(self.S2),
381
+ self.backend.bk_relu(self.S2L),
382
+ self.j1,
383
+ self.j2,
384
+ backend=self.backend,
385
+ )
386
+
387
+ def __rmul__(self, other):
388
+ assert (
389
+ isinstance(other, float)
390
+ or isinstance(other, np.float32)
391
+ or isinstance(other, int)
392
+ or isinstance(other, bool)
393
+ or isinstance(other, scat)
394
+ )
395
+
285
396
  if isinstance(other, scat):
286
- return scat(self.domult(self.P00, other.P00), \
287
- self.domult(self.S0, other.S0), \
288
- self.domult(self.S1, other.S1), \
289
- self.domult(self.S2, other.S2), \
290
- self.domult(self.S2L, other.S2L), \
291
- self.j1,self.j2,backend=self.backend)
397
+ return scat(
398
+ self.domult(self.P00, other.P00),
399
+ self.domult(self.S0, other.S0),
400
+ self.domult(self.S1, other.S1),
401
+ self.domult(self.S2, other.S2),
402
+ self.domult(self.S2L, other.S2L),
403
+ self.j1,
404
+ self.j2,
405
+ backend=self.backend,
406
+ )
292
407
  else:
293
- return scat((self.P00* other), \
294
- (self.S0* other), \
295
- (self.S1* other), \
296
- (self.S2* other), \
297
- (self.S2L* other), \
298
- self.j1,self.j2,backend=self.backend)
299
-
300
- def l1_abs(self,x):
301
- y=self.get_np(x)
408
+ return scat(
409
+ (self.P00 * other),
410
+ (self.S0 * other),
411
+ (self.S1 * other),
412
+ (self.S2 * other),
413
+ (self.S2L * other),
414
+ self.j1,
415
+ self.j2,
416
+ backend=self.backend,
417
+ )
418
+
419
+ def l1_abs(self, x):
420
+ y = self.get_np(x)
302
421
  if self.backend.bk_is_complex(y):
303
- tmp=y.real*y.real+y.imag*y.imag
304
- tmp=np.sign(tmp)*np.sqrt(np.fabs(tmp))
305
- y=tmp
306
-
307
- return(y)
308
-
309
- def plot(self,name=None,hold=True,color='blue',lw=1,legend=True):
422
+ tmp = y.real * y.real + y.imag * y.imag
423
+ tmp = np.sign(tmp) * np.sqrt(np.fabs(tmp))
424
+ y = tmp
425
+
426
+ return y
427
+
428
+ def plot(self, name=None, hold=True, color="blue", lw=1, legend=True):
310
429
 
311
430
  import matplotlib.pyplot as plt
312
431
 
313
- j1,j2=self.get_j_idx()
314
-
432
+ j1, j2 = self.get_j_idx()
433
+
315
434
  if name is None:
316
- name=''
435
+ name = ""
317
436
 
318
437
  if hold:
319
- plt.figure(figsize=(16,8))
320
-
321
- test=None
438
+ plt.figure(figsize=(16, 8))
439
+
440
+ test = None
322
441
  plt.subplot(2, 2, 1)
323
- tmp=abs(self.get_np(self.S1))
324
- if len(tmp.shape)==4:
442
+ tmp = abs(self.get_np(self.S1))
443
+ if len(tmp.shape) == 4:
325
444
  for k in range(tmp.shape[3]):
326
445
  for i1 in range(tmp.shape[0]):
327
446
  for i2 in range(tmp.shape[1]):
328
447
  if test is None:
329
- test=1
330
- plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
448
+ test = 1
449
+ plt.plot(
450
+ tmp[i1, i2, :, k],
451
+ color=color,
452
+ label=r"%s $S_{1}$" % (name),
453
+ lw=lw,
454
+ )
331
455
  else:
332
- plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
456
+ plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
333
457
  else:
334
458
  for k in range(tmp.shape[2]):
335
459
  for i1 in range(tmp.shape[0]):
336
460
  if test is None:
337
- test=1
338
- plt.plot(tmp[i1,:,k],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
461
+ test = 1
462
+ plt.plot(
463
+ tmp[i1, :, k],
464
+ color=color,
465
+ label=r"%s $S_{1}$" % (name),
466
+ lw=lw,
467
+ )
339
468
  else:
340
- plt.plot(tmp[i1,:,k],color=color, lw=lw)
341
- plt.yscale('log')
342
- plt.ylabel('S1')
343
- plt.xlabel(r'$j_{1}$')
469
+ plt.plot(tmp[i1, :, k], color=color, lw=lw)
470
+ plt.yscale("log")
471
+ plt.ylabel("S1")
472
+ plt.xlabel(r"$j_{1}$")
344
473
  plt.legend()
345
474
 
346
- test=None
475
+ test = None
347
476
  plt.subplot(2, 2, 2)
348
- tmp=abs(self.get_np(self.P00))
349
- if len(tmp.shape)==4:
477
+ tmp = abs(self.get_np(self.P00))
478
+ if len(tmp.shape) == 4:
350
479
  for k in range(tmp.shape[3]):
351
480
  for i1 in range(tmp.shape[0]):
352
481
  for i2 in range(tmp.shape[1]):
353
482
  if test is None:
354
- test=1
355
- plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
483
+ test = 1
484
+ plt.plot(
485
+ tmp[i1, i2, :, k],
486
+ color=color,
487
+ label=r"%s $P_{00}$" % (name),
488
+ lw=lw,
489
+ )
356
490
  else:
357
- plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
491
+ plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
358
492
  else:
359
493
  for k in range(tmp.shape[2]):
360
494
  for i1 in range(tmp.shape[0]):
361
495
  if test is None:
362
- test=1
363
- plt.plot(tmp[i1,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
496
+ test = 1
497
+ plt.plot(
498
+ tmp[i1, :, k],
499
+ color=color,
500
+ label=r"%s $P_{00}$" % (name),
501
+ lw=lw,
502
+ )
364
503
  else:
365
- plt.plot(tmp[i1,:,k],color=color, lw=lw)
366
- plt.yscale('log')
367
- plt.ylabel('P00')
368
- plt.xlabel(r'$j_{1}$')
504
+ plt.plot(tmp[i1, :, k], color=color, lw=lw)
505
+ plt.yscale("log")
506
+ plt.ylabel("P00")
507
+ plt.xlabel(r"$j_{1}$")
369
508
  plt.legend()
370
-
371
- ax1=plt.subplot(2, 2, 3)
509
+
510
+ ax1 = plt.subplot(2, 2, 3)
372
511
  ax2 = ax1.twiny()
373
- n=0
374
- tmp=abs(self.get_np(self.S2))
375
- lname=r'%s $S_{2}$' % (name)
376
- ax1.set_ylabel(r'$S_{2}$ [L1 norm]')
377
- test=None
378
- tabx=[]
379
- tabnx=[]
380
- tab2x=[]
381
- tab2nx=[]
382
- if len(tmp.shape)==5:
512
+ n = 0
513
+ tmp = abs(self.get_np(self.S2))
514
+ lname = r"%s $S_{2}$" % (name)
515
+ ax1.set_ylabel(r"$S_{2}$ [L1 norm]")
516
+ test = None
517
+ tabx = []
518
+ tabnx = []
519
+ tab2x = []
520
+ tab2nx = []
521
+ if len(tmp.shape) == 5:
383
522
  for i0 in range(tmp.shape[0]):
384
523
  for i1 in range(tmp.shape[1]):
385
- for i2 in range(j1.max()+1):
524
+ for i2 in range(j1.max() + 1):
386
525
  for i3 in range(tmp.shape[3]):
387
526
  for i4 in range(tmp.shape[4]):
388
- if j2[j1==i2].shape[0]==1:
389
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4],'.', \
390
- color=color, lw=lw)
527
+ if j2[j1 == i2].shape[0] == 1:
528
+ ax1.plot(
529
+ j2[j1 == i2] + n,
530
+ tmp[i0, i1, j1 == i2, i3, i4],
531
+ ".",
532
+ color=color,
533
+ lw=lw,
534
+ )
391
535
  else:
392
536
  if legend and test is None:
393
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
394
- color=color, label=lname, lw=lw)
395
- test=1
396
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
397
- color=color, lw=lw)
398
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
399
- tabx=tabx+[k+n for k in j2[j1==i2]]
400
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
401
- tab2nx=tab2nx+['%d'%(i2)]
402
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
403
- n=n+j2[j1==i2].shape[0]-1
537
+ ax1.plot(
538
+ j2[j1 == i2] + n,
539
+ tmp[i0, i1, j1 == i2, i3, i4],
540
+ color=color,
541
+ label=lname,
542
+ lw=lw,
543
+ )
544
+ test = 1
545
+ ax1.plot(
546
+ j2[j1 == i2] + n,
547
+ tmp[i0, i1, j1 == i2, i3, i4],
548
+ color=color,
549
+ lw=lw,
550
+ )
551
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
552
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
553
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
554
+ tab2nx = tab2nx + ["%d" % (i2)]
555
+ ax1.axvline(
556
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
557
+ )
558
+ n = n + j2[j1 == i2].shape[0] - 1
404
559
  else:
405
560
  for i0 in range(tmp.shape[0]):
406
- for i2 in range(j1.max()+1):
561
+ for i2 in range(j1.max() + 1):
407
562
  for i3 in range(tmp.shape[2]):
408
563
  for i4 in range(tmp.shape[3]):
409
- if j2[j1==i2].shape[0]==1:
410
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4],'.', \
411
- color=color, lw=lw)
564
+ if j2[j1 == i2].shape[0] == 1:
565
+ ax1.plot(
566
+ j2[j1 == i2] + n,
567
+ tmp[i0, j1 == i2, i3, i4],
568
+ ".",
569
+ color=color,
570
+ lw=lw,
571
+ )
412
572
  else:
413
573
  if legend and test is None:
414
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
415
- color=color, label=lname, lw=lw)
416
- test=1
417
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
418
- color=color, lw=lw)
419
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
420
- tabx=tabx+[k+n for k in j2[j1==i2]]
421
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
422
- tab2nx=tab2nx+['%d'%(i2)]
423
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
424
- n=n+j2[j1==i2].shape[0]-1
425
- plt.yscale('log')
426
- ax1.set_xlim(0,n+2)
574
+ ax1.plot(
575
+ j2[j1 == i2] + n,
576
+ tmp[i0, j1 == i2, i3, i4],
577
+ color=color,
578
+ label=lname,
579
+ lw=lw,
580
+ )
581
+ test = 1
582
+ ax1.plot(
583
+ j2[j1 == i2] + n,
584
+ tmp[i0, j1 == i2, i3, i4],
585
+ color=color,
586
+ lw=lw,
587
+ )
588
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
589
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
590
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
591
+ tab2nx = tab2nx + ["%d" % (i2)]
592
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
593
+ n = n + j2[j1 == i2].shape[0] - 1
594
+ plt.yscale("log")
595
+ ax1.set_xlim(0, n + 2)
427
596
  ax1.set_xticks(tabx)
428
- ax1.set_xticklabels(tabnx,fontsize=6)
429
- ax1.set_xlabel(r"$j_{2}$ ",fontsize=6)
430
-
597
+ ax1.set_xticklabels(tabnx, fontsize=6)
598
+ ax1.set_xlabel(r"$j_{2}$ ", fontsize=6)
599
+
431
600
  # Move twinned axis ticks and label from top to bottom
432
601
  ax2.xaxis.set_ticks_position("bottom")
433
602
  ax2.xaxis.set_label_position("bottom")
@@ -435,7 +604,7 @@ class scat:
435
604
  # Offset the twin axis below the host
436
605
  ax2.spines["bottom"].set_position(("axes", -0.15))
437
606
 
438
- # Turn on the frame for the twin axis, but then hide all
607
+ # Turn on the frame for the twin axis, but then hide all
439
608
  # but the bottom spine
440
609
  ax2.set_frame_on(True)
441
610
  ax2.patch.set_visible(False)
@@ -443,72 +612,102 @@ class scat:
443
612
  for sp in ax2.spines.values():
444
613
  sp.set_visible(False)
445
614
  ax2.spines["bottom"].set_visible(True)
446
- ax2.set_xlim(0,n+2)
615
+ ax2.set_xlim(0, n + 2)
447
616
  ax2.set_xticks(tab2x)
448
- ax2.set_xticklabels(tab2nx,fontsize=6)
449
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
617
+ ax2.set_xticklabels(tab2nx, fontsize=6)
618
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
450
619
  ax1.legend(frameon=0)
451
-
452
- ax1=plt.subplot(2, 2, 4)
620
+
621
+ ax1 = plt.subplot(2, 2, 4)
453
622
  ax2 = ax1.twiny()
454
- n=0
455
- tmp=abs(self.get_np(self.S2L))
456
- lname=r'%s $S2_{2}$' % (name)
457
- ax1.set_ylabel(r'$S_{2}$ [L2 norm]')
458
- test=None
459
- tabx=[]
460
- tabnx=[]
461
- tab2x=[]
462
- tab2nx=[]
463
- if len(tmp.shape)==5:
623
+ n = 0
624
+ tmp = abs(self.get_np(self.S2L))
625
+ lname = r"%s $S2_{2}$" % (name)
626
+ ax1.set_ylabel(r"$S_{2}$ [L2 norm]")
627
+ test = None
628
+ tabx = []
629
+ tabnx = []
630
+ tab2x = []
631
+ tab2nx = []
632
+ if len(tmp.shape) == 5:
464
633
  for i0 in range(tmp.shape[0]):
465
634
  for i1 in range(tmp.shape[1]):
466
- for i2 in range(j1.max()+1):
635
+ for i2 in range(j1.max() + 1):
467
636
  for i3 in range(tmp.shape[3]):
468
637
  for i4 in range(tmp.shape[4]):
469
- if j2[j1==i2].shape[0]==1:
470
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4],'.', \
471
- color=color, lw=lw)
638
+ if j2[j1 == i2].shape[0] == 1:
639
+ ax1.plot(
640
+ j2[j1 == i2] + n,
641
+ tmp[i0, i1, j1 == i2, i3, i4],
642
+ ".",
643
+ color=color,
644
+ lw=lw,
645
+ )
472
646
  else:
473
647
  if legend and test is None:
474
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
475
- color=color, label=lname, lw=lw)
476
- test=1
477
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
478
- color=color, lw=lw)
479
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
480
- tabx=tabx+[k+n for k in j2[j1==i2]]
481
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
482
- tab2nx=tab2nx+['%d'%(i2)]
483
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
484
- n=n+j2[j1==i2].shape[0]-1
648
+ ax1.plot(
649
+ j2[j1 == i2] + n,
650
+ tmp[i0, i1, j1 == i2, i3, i4],
651
+ color=color,
652
+ label=lname,
653
+ lw=lw,
654
+ )
655
+ test = 1
656
+ ax1.plot(
657
+ j2[j1 == i2] + n,
658
+ tmp[i0, i1, j1 == i2, i3, i4],
659
+ color=color,
660
+ lw=lw,
661
+ )
662
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
663
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
664
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
665
+ tab2nx = tab2nx + ["%d" % (i2)]
666
+ ax1.axvline(
667
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
668
+ )
669
+ n = n + j2[j1 == i2].shape[0] - 1
485
670
  else:
486
671
  for i0 in range(tmp.shape[0]):
487
- for i2 in range(j1.max()+1):
672
+ for i2 in range(j1.max() + 1):
488
673
  for i3 in range(tmp.shape[2]):
489
674
  for i4 in range(tmp.shape[3]):
490
- if j2[j1==i2].shape[0]==1:
491
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4],'.', \
492
- color=color, lw=lw)
675
+ if j2[j1 == i2].shape[0] == 1:
676
+ ax1.plot(
677
+ j2[j1 == i2] + n,
678
+ tmp[i0, j1 == i2, i3, i4],
679
+ ".",
680
+ color=color,
681
+ lw=lw,
682
+ )
493
683
  else:
494
684
  if legend and test is None:
495
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
496
- color=color, label=lname, lw=lw)
497
- test=1
498
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
499
- color=color, lw=lw)
500
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
501
- tabx=tabx+[k+n for k in j2[j1==i2]]
502
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
503
- tab2nx=tab2nx+['%d'%(i2)]
504
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
505
- n=n+j2[j1==i2].shape[0]-1
506
- plt.yscale('log')
507
- ax1.set_xlim(-1,n+3)
685
+ ax1.plot(
686
+ j2[j1 == i2] + n,
687
+ tmp[i0, j1 == i2, i3, i4],
688
+ color=color,
689
+ label=lname,
690
+ lw=lw,
691
+ )
692
+ test = 1
693
+ ax1.plot(
694
+ j2[j1 == i2] + n,
695
+ tmp[i0, j1 == i2, i3, i4],
696
+ color=color,
697
+ lw=lw,
698
+ )
699
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
700
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
701
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
702
+ tab2nx = tab2nx + ["%d" % (i2)]
703
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
704
+ n = n + j2[j1 == i2].shape[0] - 1
705
+ plt.yscale("log")
706
+ ax1.set_xlim(-1, n + 3)
508
707
  ax1.set_xticks(tabx)
509
- ax1.set_xticklabels(tabnx,fontsize=6)
510
- ax1.set_xlabel(r"$j_{2}$",fontsize=6)
511
-
708
+ ax1.set_xticklabels(tabnx, fontsize=6)
709
+ ax1.set_xlabel(r"$j_{2}$", fontsize=6)
710
+
512
711
  # Move twinned axis ticks and label from top to bottom
513
712
  ax2.xaxis.set_ticks_position("bottom")
514
713
  ax2.xaxis.set_label_position("bottom")
@@ -516,7 +715,7 @@ class scat:
516
715
  # Offset the twin axis below the host
517
716
  ax2.spines["bottom"].set_position(("axes", -0.15))
518
717
 
519
- # Turn on the frame for the twin axis, but then hide all
718
+ # Turn on the frame for the twin axis, but then hide all
520
719
  # but the bottom spine
521
720
  ax2.set_frame_on(True)
522
721
  ax2.patch.set_visible(False)
@@ -524,252 +723,351 @@ class scat:
524
723
  for sp in ax2.spines.values():
525
724
  sp.set_visible(False)
526
725
  ax2.spines["bottom"].set_visible(True)
527
- ax2.set_xlim(0,n+3)
726
+ ax2.set_xlim(0, n + 3)
528
727
  ax2.set_xticks(tab2x)
529
- ax2.set_xticklabels(tab2nx,fontsize=6)
530
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
728
+ ax2.set_xticklabels(tab2nx, fontsize=6)
729
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
531
730
  ax1.legend(frameon=0)
532
-
533
- def save(self,filename):
534
- outlist=[self.get_S0().numpy(), \
535
- self.get_S1().numpy(), \
536
- self.get_S2().numpy(), \
537
- self.get_S2L().numpy(), \
538
- self.get_P00().numpy(), \
539
- self.j1, \
540
- self.j2]
541
-
542
- myout=open("%s.pkl"%(filename),"wb")
543
- pickle.dump(outlist,myout)
731
+
732
+ def save(self, filename):
733
+ outlist = [
734
+ self.get_S0().numpy(),
735
+ self.get_S1().numpy(),
736
+ self.get_S2().numpy(),
737
+ self.get_S2L().numpy(),
738
+ self.get_P00().numpy(),
739
+ self.j1,
740
+ self.j2,
741
+ ]
742
+
743
+ myout = open("%s.pkl" % (filename), "wb")
744
+ pickle.dump(outlist, myout)
544
745
  myout.close()
545
746
 
546
-
547
- def read(self,filename):
548
-
549
- outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
550
- return scat(outlist[4],outlist[0],outlist[1],outlist[2],outlist[3],outlist[5],outlist[6],backend=bk.foscat_backend('numpy'))
551
-
552
- def get_np(self,x):
747
+ def read(self, filename):
748
+
749
+ outlist = pickle.load(open("%s.pkl" % (filename), "rb"))
750
+ return scat(
751
+ outlist[4],
752
+ outlist[0],
753
+ outlist[1],
754
+ outlist[2],
755
+ outlist[3],
756
+ outlist[5],
757
+ outlist[6],
758
+ backend=bk.foscat_backend("numpy"),
759
+ )
760
+
761
+ def get_np(self, x):
553
762
  if isinstance(x, np.ndarray):
554
763
  return x
555
764
  else:
556
765
  return x.numpy()
557
766
 
558
767
  def std(self):
559
- return np.sqrt(((abs(self.get_np(self.S0)).std())**2+ \
560
- (abs(self.get_np(self.S1)).std())**2+ \
561
- (abs(self.get_np(self.S2)).std())**2+ \
562
- (abs(self.get_np(self.S2L)).std())**2+ \
563
- (abs(self.get_np(self.P00)).std())**2)/4)
768
+ return np.sqrt(
769
+ (
770
+ (abs(self.get_np(self.S0)).std()) ** 2
771
+ + (abs(self.get_np(self.S1)).std()) ** 2
772
+ + (abs(self.get_np(self.S2)).std()) ** 2
773
+ + (abs(self.get_np(self.S2L)).std()) ** 2
774
+ + (abs(self.get_np(self.P00)).std()) ** 2
775
+ )
776
+ / 4
777
+ )
564
778
 
565
779
  def mean(self):
566
- return abs(self.get_np(self.S0).mean()+ \
567
- self.get_np(self.S1).mean()+ \
568
- self.get_np(self.S2).mean()+ \
569
- self.get_np(self.S2L).mean()+ \
570
- self.get_np(self.P00).mean())/3
780
+ return (
781
+ abs(
782
+ self.get_np(self.S0).mean()
783
+ + self.get_np(self.S1).mean()
784
+ + self.get_np(self.S2).mean()
785
+ + self.get_np(self.S2L).mean()
786
+ + self.get_np(self.P00).mean()
787
+ )
788
+ / 3
789
+ )
571
790
 
572
791
  def sqrt(self):
573
792
 
793
+ s0 = self.backend.bk_sqrt(self.S0)
794
+ s1 = self.backend.bk_sqrt(self.S1)
795
+ p00 = self.backend.bk_sqrt(self.P00)
796
+ s2 = self.backend.bk_sqrt(self.S2)
797
+ s2L = self.backend.bk_sqrt(self.S2L)
574
798
 
575
- s0 =self.backend.bk_sqrt(self.S0)
576
- s1 =self.backend.bk_sqrt(self.S1)
577
- p00=self.backend.bk_sqrt(self.P00)
578
- s2 =self.backend.bk_sqrt(self.S2)
579
- s2L=self.backend.bk_sqrt(self.S2L)
580
-
581
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
582
-
799
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
583
800
 
584
801
  def L1(self):
585
802
 
803
+ s0 = self.backend.bk_L1(self.S0)
804
+ s1 = self.backend.bk_L1(self.S1)
805
+ p00 = self.backend.bk_L1(self.P00)
806
+ s2 = self.backend.bk_L1(self.S2)
807
+ s2L = self.backend.bk_L1(self.S2L)
808
+
809
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
586
810
 
587
- s0 =self.backend.bk_L1(self.S0)
588
- s1 =self.backend.bk_L1(self.S1)
589
- p00=self.backend.bk_L1(self.P00)
590
- s2 =self.backend.bk_L1(self.S2)
591
- s2L=self.backend.bk_L1(self.S2L)
592
-
593
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
594
-
595
811
  def square_comp(self):
596
812
 
813
+ s0 = self.backend.bk_square_comp(self.S0)
814
+ s1 = self.backend.bk_square_comp(self.S1)
815
+ p00 = self.backend.bk_square_comp(self.P00)
816
+ s2 = self.backend.bk_square_comp(self.S2)
817
+ s2L = self.backend.bk_square_comp(self.S2L)
818
+
819
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
597
820
 
598
- s0 =self.backend.bk_square_comp(self.S0)
599
- s1 =self.backend.bk_square_comp(self.S1)
600
- p00=self.backend.bk_square_comp(self.P00)
601
- s2 =self.backend.bk_square_comp(self.S2)
602
- s2L=self.backend.bk_square_comp(self.S2L)
603
-
604
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
605
-
606
- def iso_mean(self,repeat=False):
607
- shape=list(self.S2.shape)
608
- norient=self.S1.shape[2]
821
+ def iso_mean(self, repeat=False):
822
+ shape = list(self.S2.shape)
823
+ norient = self.S1.shape[2]
609
824
 
610
- S1 = self.backend.bk_reduce_mean(self.S1,2)
825
+ S1 = self.backend.bk_reduce_mean(self.S1, 2)
611
826
  if repeat:
612
- S1=self.backend.bk_reshape(self.backend.bk_repeat(S1,shape[2],1),self.S1.shape)
827
+ S1 = self.backend.bk_reshape(
828
+ self.backend.bk_repeat(S1, shape[2], 1), self.S1.shape
829
+ )
613
830
  else:
614
- S1=self.backend.bk_expand_dims(S1,-1)
615
-
831
+ S1 = self.backend.bk_expand_dims(S1, -1)
616
832
 
617
- P00 = self.backend.bk_reduce_mean(self.P00,2)
833
+ P00 = self.backend.bk_reduce_mean(self.P00, 2)
618
834
  if repeat:
619
- P00=self.backend.bk_reshape(self.backend.bk_repeat(P00,shape[2],1),self.S1.shape)
835
+ P00 = self.backend.bk_reshape(
836
+ self.backend.bk_repeat(P00, shape[2], 1), self.S1.shape
837
+ )
620
838
  else:
621
- P00=self.backend.bk_expand_dims(P00,-1)
839
+ P00 = self.backend.bk_expand_dims(P00, -1)
622
840
 
623
841
  if norient not in self.backend._iso_orient:
624
842
  self.backend.calc_iso_orient(norient)
625
-
843
+
626
844
  if self.backend.bk_is_complex(self.S2):
627
- lmat = self.backend._iso_orient_C[norient]
845
+ lmat = self.backend._iso_orient_C[norient]
628
846
  lmat_T = self.backend._iso_orient_C_T[norient]
629
847
  else:
630
- lmat = self.backend._iso_orient[norient]
848
+ lmat = self.backend._iso_orient[norient]
631
849
  lmat_T = self.backend._iso_orient_T[norient]
632
-
633
- S2=self.backend.bk_reshape(
634
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2,[shape[0],shape[1],norient*norient]),lmat),
635
- [shape[0],shape[1],norient])
636
- S2L=self.backend.bk_reshape(
637
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2L,[shape[0],shape[1],norient*norient]),lmat),
638
- [shape[0],shape[1],norient])
639
-
850
+
851
+ S2 = self.backend.bk_reshape(
852
+ self.backend.backend.matmul(
853
+ self.backend.bk_reshape(
854
+ self.S2, [shape[0], shape[1], norient * norient]
855
+ ),
856
+ lmat,
857
+ ),
858
+ [shape[0], shape[1], norient],
859
+ )
860
+ S2L = self.backend.bk_reshape(
861
+ self.backend.backend.matmul(
862
+ self.backend.bk_reshape(
863
+ self.S2L, [shape[0], shape[1], norient * norient]
864
+ ),
865
+ lmat,
866
+ ),
867
+ [shape[0], shape[1], norient],
868
+ )
869
+
640
870
  if repeat:
641
-
642
- S2=self.backend.bk_reshape(
643
- self.backend.backend.matmul(self.backend.bk_reshape(S2,[shape[0]*shape[1],norient]),lmat_T),
644
- self.S2.shape)
645
- S2L=self.backend.bk_reshape(
646
- self.backend.backend.matmul(self.backend.bk_reshape(S2L,[shape[0]*shape[1],norient]),lmat_T),
647
- self.S2.shape)
871
+
872
+ S2 = self.backend.bk_reshape(
873
+ self.backend.backend.matmul(
874
+ self.backend.bk_reshape(S2, [shape[0] * shape[1], norient]), lmat_T
875
+ ),
876
+ self.S2.shape,
877
+ )
878
+ S2L = self.backend.bk_reshape(
879
+ self.backend.backend.matmul(
880
+ self.backend.bk_reshape(S2L, [shape[0] * shape[1], norient]), lmat_T
881
+ ),
882
+ self.S2.shape,
883
+ )
648
884
  else:
649
- S2=self.backend.bk_expand_dims(S2,-1)
650
- S2L=self.backend.bk_expand_dims(S2L,-1)
885
+ S2 = self.backend.bk_expand_dims(S2, -1)
886
+ S2L = self.backend.bk_expand_dims(S2L, -1)
651
887
 
652
- return scat(P00,self.S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
888
+ return scat(P00, self.S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
653
889
 
654
-
655
- def fft_ang(self,nharm=1,imaginary=False):
656
- shape=list(self.S2.shape)
657
- norient=self.S1.shape[2]
890
+ def fft_ang(self, nharm=1, imaginary=False):
891
+ shape = list(self.S2.shape)
892
+ norient = self.S1.shape[2]
658
893
 
659
- nout=1+nharm
894
+ nout = 1 + nharm
660
895
  if imaginary:
661
- nout=1+nharm*2
662
-
663
- if (norient,nharm) not in self.backend._fft_1_orient:
664
- self.backend.calc_fft_orient(norient,nharm,imaginary)
665
-
896
+ nout = 1 + nharm * 2
897
+
898
+ if (norient, nharm) not in self.backend._fft_1_orient:
899
+ self.backend.calc_fft_orient(norient, nharm, imaginary)
900
+
666
901
  if self.backend.bk_is_complex(self.S1):
667
- lmat = self.backend._fft_1_orient_C[(norient,nharm,imaginary)]
902
+ lmat = self.backend._fft_1_orient_C[(norient, nharm, imaginary)]
668
903
  else:
669
- lmat = self.backend._fft_1_orient[(norient,nharm,imaginary)]
670
-
671
- S1=self.backend.bk_reshape(
672
- self.backend.backend.matmul(self.backend.bk_reshape(self.S1,[self.S1.shape[0],self.S1.shape[1],norient]),lmat),
673
- [self.S1.shape[0],self.S1.shape[1],nout])
674
-
675
- P00=self.backend.bk_reshape(
676
- self.backend.backend.matmul(self.backend.bk_reshape(self.P00,[self.S1.shape[0],self.S1.shape[1],norient]),lmat),
677
- [self.S1.shape[0],self.S1.shape[1],nout])
678
-
679
-
904
+ lmat = self.backend._fft_1_orient[(norient, nharm, imaginary)]
905
+
906
+ S1 = self.backend.bk_reshape(
907
+ self.backend.backend.matmul(
908
+ self.backend.bk_reshape(
909
+ self.S1, [self.S1.shape[0], self.S1.shape[1], norient]
910
+ ),
911
+ lmat,
912
+ ),
913
+ [self.S1.shape[0], self.S1.shape[1], nout],
914
+ )
915
+
916
+ P00 = self.backend.bk_reshape(
917
+ self.backend.backend.matmul(
918
+ self.backend.bk_reshape(
919
+ self.P00, [self.S1.shape[0], self.S1.shape[1], norient]
920
+ ),
921
+ lmat,
922
+ ),
923
+ [self.S1.shape[0], self.S1.shape[1], nout],
924
+ )
925
+
680
926
  if self.backend.bk_is_complex(self.S2):
681
- lmat = self.backend._fft_2_orient_C[(norient,nharm,imaginary)]
927
+ lmat = self.backend._fft_2_orient_C[(norient, nharm, imaginary)]
682
928
  else:
683
- lmat = self.backend._fft_2_orient[(norient,nharm,imaginary)]
684
-
685
- S2=self.backend.bk_reshape(
686
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2,[shape[0],shape[1],norient*norient]),lmat),
687
- [shape[0],shape[1],nout,nout])
688
- S2L=self.backend.bk_reshape(
689
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2L,[shape[0],shape[1],norient*norient]),lmat),
690
- [shape[0],shape[1],nout,nout])
691
-
692
- return scat(P00,self.S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
693
-
694
-
695
- def iso_std(self,repeat=False):
696
-
697
- val=(self-self.iso_mean(repeat=True)).square_comp()
929
+ lmat = self.backend._fft_2_orient[(norient, nharm, imaginary)]
930
+
931
+ S2 = self.backend.bk_reshape(
932
+ self.backend.backend.matmul(
933
+ self.backend.bk_reshape(
934
+ self.S2, [shape[0], shape[1], norient * norient]
935
+ ),
936
+ lmat,
937
+ ),
938
+ [shape[0], shape[1], nout, nout],
939
+ )
940
+ S2L = self.backend.bk_reshape(
941
+ self.backend.backend.matmul(
942
+ self.backend.bk_reshape(
943
+ self.S2L, [shape[0], shape[1], norient * norient]
944
+ ),
945
+ lmat,
946
+ ),
947
+ [shape[0], shape[1], nout, nout],
948
+ )
949
+
950
+ return scat(P00, self.S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
951
+
952
+ def iso_std(self, repeat=False):
953
+
954
+ val = (self - self.iso_mean(repeat=True)).square_comp()
698
955
  return (val.iso_mean(repeat=repeat)).L1()
699
956
 
700
957
  # ---------------------------------------------−---------
701
- def cleanval(self,x):
702
- x=x.numpy()
703
- x[np.isfinite(x)==False]=np.median(x[np.isfinite(x)])
958
+ def cleanval(self, x):
959
+ x = x.numpy()
960
+ x[~np.isfinite(x)] = np.median(x[np.isfinite(x)])
704
961
  return x
705
962
 
706
963
  def filter_inf(self):
707
- S1 = self.cleanval(self.S1)
708
- S0 = self.cleanval(self.S0)
964
+ S1 = self.cleanval(self.S1)
965
+ S0 = self.cleanval(self.S0)
709
966
  P00 = self.cleanval(self.P00)
710
- S2 = self.cleanval(self.S2)
967
+ S2 = self.cleanval(self.S2)
711
968
  S2L = self.cleanval(self.S2L)
712
969
 
713
- return scat(P00,S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
970
+ return scat(P00, S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
714
971
 
715
972
  # ---------------------------------------------−---------
716
- def interp(self,nscale,extend=False,constant=False,threshold=1E30,use_mask=False):
717
-
718
- if nscale+2>self.S1.shape[1]:
719
- print('Can not *interp* %d with a statistic described over %d'%(nscale,self.S1.shape[1]))
720
- return scat(self.P00,self.S0,self.S1,self.S2,self.S2L,self.j1,self.j2,backend=self.backend)
721
- if isinstance(self.S1,np.ndarray):
722
- s1=self.S1
723
- p0=self.P00
724
- s2=self.S2
725
- s2l=self.S2L
973
+ def interp(
974
+ self, nscale, extend=False, constant=False, threshold=1e30, use_mask=False
975
+ ):
976
+
977
+ if nscale + 2 > self.S1.shape[1]:
978
+ print(
979
+ "Can not *interp* %d with a statistic described over %d"
980
+ % (nscale, self.S1.shape[1])
981
+ )
982
+ return scat(
983
+ self.P00,
984
+ self.S0,
985
+ self.S1,
986
+ self.S2,
987
+ self.S2L,
988
+ self.j1,
989
+ self.j2,
990
+ backend=self.backend,
991
+ )
992
+ if isinstance(self.S1, np.ndarray):
993
+ s1 = self.S1
994
+ p0 = self.P00
995
+ s2 = self.S2
996
+ s2l = self.S2L
726
997
  else:
727
- s1=self.S1.numpy()
728
- p0=self.P00.numpy()
729
- s2=self.S2.numpy()
730
- s2l=self.S2L.numpy()
731
-
732
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
733
-
734
- if isinstance(threshold,scat):
735
- if isinstance(threshold.S1,np.ndarray):
736
- s1th=threshold.S1
737
- p0th=threshold.P00
738
- s2th=threshold.S2
739
- s2lth=threshold.S2L
998
+ s1 = self.S1.numpy()
999
+ p0 = self.P00.numpy()
1000
+ s2 = self.S2.numpy()
1001
+ s2l = self.S2L.numpy()
1002
+
1003
+ print(s1.sum(), p0.sum(), s2.sum(), s2l.sum())
1004
+
1005
+ if isinstance(threshold, scat):
1006
+ if isinstance(threshold.S1, np.ndarray):
1007
+ s1th = threshold.S1
1008
+ p0th = threshold.P00
1009
+ s2th = threshold.S2
1010
+ s2lth = threshold.S2L
740
1011
  else:
741
- s1th=threshold.S1.numpy()
742
- p0th=threshold.P00.numpy()
743
- s2th=threshold.S2.numpy()
744
- s2lth=threshold.S2L.numpy()
1012
+ s1th = threshold.S1.numpy()
1013
+ p0th = threshold.P00.numpy()
1014
+ s2th = threshold.S2.numpy()
1015
+ s2lth = threshold.S2L.numpy()
745
1016
  else:
746
- s1th=threshold+0*s1
747
- p0th=threshold+0*p0
748
- s2th=threshold+0*s2
749
- s2lth=threshold+0*s2l
1017
+ s1th = threshold + 0 * s1
1018
+ p0th = threshold + 0 * p0
1019
+ s2th = threshold + 0 * s2
1020
+ s2lth = threshold + 0 * s2l
750
1021
 
751
1022
  for k in range(nscale):
752
1023
  if constant:
753
- s1[:,nscale-1-k,:]=s1[:,nscale-k,:]
754
- p0[:,nscale-1-k,:]=p0[:,nscale-k,:]
1024
+ s1[:, nscale - 1 - k, :] = s1[:, nscale - k, :]
1025
+ p0[:, nscale - 1 - k, :] = p0[:, nscale - k, :]
755
1026
  else:
756
- idx=np.where((s1[:,nscale+1-k,:]>0)*(s1[:,nscale+2-k,:]>0)*(s1[:,nscale-k,:]<s1th[:,nscale-k,:]))
757
- if len(idx[0])>0:
758
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(s1[idx[0],nscale+1-k,idx[1]])-2*np.log(s1[idx[0],nscale+2-k,idx[1]]))
759
- idx=np.where((s1[:,nscale-k,:]>0)*(s1[:,nscale+1-k,:]>0)*(s1[:,nscale-1-k,:]<s1th[:,nscale-1-k,:]))
760
- if len(idx[0])>0:
761
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(s1[idx[0],nscale-k,idx[1]])-np.log(s1[idx[0],nscale+1-k,idx[1]]))
762
-
763
- idx=np.where((p0[:,nscale+1-k,:]>0)*(p0[:,nscale+2-k,:]>0)*(p0[:,nscale-k,:]<p0th[:,nscale-k,:]))
764
- if len(idx[0])>0:
765
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(p0[idx[0],nscale+1-k,idx[1]])-2*np.log(p0[idx[0],nscale+2-k,idx[1]]))
766
-
767
- idx=np.where((p0[:,nscale-k,:]>0)*(p0[:,nscale+1-k,:]>0)*(p0[:,nscale-1-k,:]<p0th[:,nscale-1-k,:]))
768
- if len(idx[0])>0:
769
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(p0[idx[0],nscale-k,idx[1]])-np.log(p0[idx[0],nscale+1-k,idx[1]]))
770
-
771
-
772
- j1,j2=self.get_j_idx()
1027
+ idx = np.where(
1028
+ (s1[:, nscale + 1 - k, :] > 0)
1029
+ * (s1[:, nscale + 2 - k, :] > 0)
1030
+ * (s1[:, nscale - k, :] < s1th[:, nscale - k, :])
1031
+ )
1032
+ if len(idx[0]) > 0:
1033
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1034
+ 3 * np.log(s1[idx[0], nscale + 1 - k, idx[1]])
1035
+ - 2 * np.log(s1[idx[0], nscale + 2 - k, idx[1]])
1036
+ )
1037
+ idx = np.where(
1038
+ (s1[:, nscale - k, :] > 0)
1039
+ * (s1[:, nscale + 1 - k, :] > 0)
1040
+ * (s1[:, nscale - 1 - k, :] < s1th[:, nscale - 1 - k, :])
1041
+ )
1042
+ if len(idx[0]) > 0:
1043
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1044
+ 2 * np.log(s1[idx[0], nscale - k, idx[1]])
1045
+ - np.log(s1[idx[0], nscale + 1 - k, idx[1]])
1046
+ )
1047
+
1048
+ idx = np.where(
1049
+ (p0[:, nscale + 1 - k, :] > 0)
1050
+ * (p0[:, nscale + 2 - k, :] > 0)
1051
+ * (p0[:, nscale - k, :] < p0th[:, nscale - k, :])
1052
+ )
1053
+ if len(idx[0]) > 0:
1054
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1055
+ 3 * np.log(p0[idx[0], nscale + 1 - k, idx[1]])
1056
+ - 2 * np.log(p0[idx[0], nscale + 2 - k, idx[1]])
1057
+ )
1058
+
1059
+ idx = np.where(
1060
+ (p0[:, nscale - k, :] > 0)
1061
+ * (p0[:, nscale + 1 - k, :] > 0)
1062
+ * (p0[:, nscale - 1 - k, :] < p0th[:, nscale - 1 - k, :])
1063
+ )
1064
+ if len(idx[0]) > 0:
1065
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1066
+ 2 * np.log(p0[idx[0], nscale - k, idx[1]])
1067
+ - np.log(p0[idx[0], nscale + 1 - k, idx[1]])
1068
+ )
1069
+
1070
+ j1, j2 = self.get_j_idx()
773
1071
 
774
1072
  for k in range(nscale):
775
1073
 
@@ -782,669 +1080,942 @@ class scat:
782
1080
  s2l[:,i0]=np.exp(2*np.log(s2l[:,i1])-np.log(s2l[:,i2]))
783
1081
  """
784
1082
 
785
- for l in range(nscale-k):
786
- i0=np.where((j1==nscale-1-k-l)*(j2==nscale-1-k))[0]
787
- i1=np.where((j1==nscale-1-k-l)*(j2==nscale -k))[0]
788
- i2=np.where((j1==nscale-1-k-l)*(j2==nscale+1-k))[0]
789
- i3=np.where((j1==nscale-1-k-l)*(j2==nscale+2-k))[0]
790
-
1083
+ for l_scale in range(nscale - k):
1084
+ i0 = np.where((j1 == nscale - 1 - k - l_scale) * (j2 == nscale - 1 - k))[0]
1085
+ i1 = np.where((j1 == nscale - 1 - k - l_scale) * (j2 == nscale - k))[0]
1086
+ i2 = np.where((j1 == nscale - 1 - k - l_scale) * (j2 == nscale + 1 - k))[0]
1087
+ i3 = np.where((j1 == nscale - 1 - k - l_scale) * (j2 == nscale + 2 - k))[0]
1088
+
791
1089
  if constant:
792
- s2[:,i0]=s2[:,i1]
793
- s2l[:,i0]=s2l[:,i1]
1090
+ s2[:, i0] = s2[:, i1]
1091
+ s2l[:, i0] = s2l[:, i1]
794
1092
  else:
795
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
796
- if len(idx[0])>0:
797
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
798
-
799
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
800
- if len(idx[0])>0:
801
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
802
-
803
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
804
- if len(idx[0])>0:
805
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
806
-
807
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
808
- if len(idx[0])>0:
809
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
810
-
1093
+ idx = np.where(
1094
+ (s2[:, i2] > 0) * (s2[:, i3] > 0) * (s2[:, i2] < s2th[:, i2])
1095
+ )
1096
+ if len(idx[0]) > 0:
1097
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1098
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
1099
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
1100
+ )
1101
+
1102
+ idx = np.where(
1103
+ (s2[:, i1] > 0) * (s2[:, i2] > 0) * (s2[:, i1] < s2th[:, i1])
1104
+ )
1105
+ if len(idx[0]) > 0:
1106
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1107
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1108
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1109
+ )
1110
+
1111
+ idx = np.where(
1112
+ (s2l[:, i2] > 0)
1113
+ * (s2l[:, i3] > 0)
1114
+ * (s2l[:, i2] < s2lth[:, i2])
1115
+ )
1116
+ if len(idx[0]) > 0:
1117
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1118
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1119
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1120
+ )
1121
+
1122
+ idx = np.where(
1123
+ (s2l[:, i1] > 0)
1124
+ * (s2l[:, i2] > 0)
1125
+ * (s2l[:, i1] < s2lth[:, i1])
1126
+ )
1127
+ if len(idx[0]) > 0:
1128
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1129
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1130
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1131
+ )
1132
+
811
1133
  if extend:
812
1134
  for k in range(nscale):
813
- for l in range(1,nscale):
814
- i0=np.where((j1==2*nscale-1-k)*(j2==2*nscale-1-k-l))[0]
815
- i1=np.where((j1==2*nscale-1-k)*(j2==2*nscale -k-l))[0]
816
- i2=np.where((j1==2*nscale-1-k)*(j2==2*nscale+1-k-l))[0]
817
- i3=np.where((j1==2*nscale-1-k)*(j2==2*nscale+2-k-l))[0]
1135
+ for l_scale in range(1, nscale):
1136
+ i0 = np.where(
1137
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale - 1 - k - l_scale)
1138
+ )[0]
1139
+ i1 = np.where(
1140
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale - k - l_scale)
1141
+ )[0]
1142
+ i2 = np.where(
1143
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale + 1 - k - l_scale)
1144
+ )[0]
1145
+ i3 = np.where(
1146
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale + 2 - k - l_scale)
1147
+ )[0]
818
1148
  if constant:
819
- s2[:,i0]=s2[:,i1]
820
- s2l[:,i0]=s2l[:,i1]
1149
+ s2[:, i0] = s2[:, i1]
1150
+ s2l[:, i0] = s2l[:, i1]
821
1151
  else:
822
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
823
- if len(idx[0])>0:
824
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
825
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
826
- if len(idx[0])>0:
827
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
828
-
829
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
830
- if len(idx[0])>0:
831
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
832
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
833
- if len(idx[0])>0:
834
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
835
-
836
- s1[np.isnan(s1)]=0.0
837
- p0[np.isnan(p0)]=0.0
838
- s2[np.isnan(s2)]=0.0
839
- s2l[np.isnan(s2l)]=0.0
840
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
841
-
842
- return scat(self.backend.constant(p0),self.S0,
843
- self.backend.constant(s1),
844
- self.backend.constant(s2),
845
- self.backend.constant(s2l),self.j1,self.j2,backend=self.backend)
1152
+ idx = np.where(
1153
+ (s2[:, i2] > 0)
1154
+ * (s2[:, i3] > 0)
1155
+ * (s2[:, i2] < s2th[:, i2])
1156
+ )
1157
+ if len(idx[0]) > 0:
1158
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1159
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
1160
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
1161
+ )
1162
+ idx = np.where(
1163
+ (s2[:, i1] > 0)
1164
+ * (s2[:, i2] > 0)
1165
+ * (s2[:, i1] < s2th[:, i1])
1166
+ )
1167
+ if len(idx[0]) > 0:
1168
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1169
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1170
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1171
+ )
1172
+
1173
+ idx = np.where(
1174
+ (s2l[:, i2] > 0)
1175
+ * (s2l[:, i3] > 0)
1176
+ * (s2l[:, i2] < s2lth[:, i2])
1177
+ )
1178
+ if len(idx[0]) > 0:
1179
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1180
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1181
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1182
+ )
1183
+ idx = np.where(
1184
+ (s2l[:, i1] > 0)
1185
+ * (s2l[:, i2] > 0)
1186
+ * (s2l[:, i1] < s2lth[:, i1])
1187
+ )
1188
+ if len(idx[0]) > 0:
1189
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1190
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1191
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1192
+ )
1193
+
1194
+ s1[np.isnan(s1)] = 0.0
1195
+ p0[np.isnan(p0)] = 0.0
1196
+ s2[np.isnan(s2)] = 0.0
1197
+ s2l[np.isnan(s2l)] = 0.0
1198
+ print(s1.sum(), p0.sum(), s2.sum(), s2l.sum())
1199
+
1200
+ return scat(
1201
+ self.backend.constant(p0),
1202
+ self.S0,
1203
+ self.backend.constant(s1),
1204
+ self.backend.constant(s2),
1205
+ self.backend.constant(s2l),
1206
+ self.j1,
1207
+ self.j2,
1208
+ backend=self.backend,
1209
+ )
846
1210
 
847
1211
  # ---------------------------------------------−---------
848
1212
  def flatten(self):
849
- if isinstance(self.S1,np.ndarray):
850
- return np.concatenate([self.S0.flatten(),
851
- self.S1.flatten(),
852
- self.P00.flatten(),
853
- self.S2.flatten(),
854
- self.S2L.flatten()],0)
1213
+ if isinstance(self.S1, np.ndarray):
1214
+ return np.concatenate(
1215
+ [
1216
+ self.S0.flatten(),
1217
+ self.S1.flatten(),
1218
+ self.P00.flatten(),
1219
+ self.S2.flatten(),
1220
+ self.S2L.flatten(),
1221
+ ],
1222
+ 0,
1223
+ )
855
1224
  else:
856
- return self.backend.bk_concat([self.backend.bk_flattenR(self.S0),
857
- self.backend.bk_flattenR(self.S1),
858
- self.backend.bk_flattenR(self.P00),
859
- self.backend.bk_flattenR(self.S2),
860
- self.backend.bk_flattenR(self.S2)],axis=0)
1225
+ return self.backend.bk_concat(
1226
+ [
1227
+ self.backend.bk_flattenR(self.S0),
1228
+ self.backend.bk_flattenR(self.S1),
1229
+ self.backend.bk_flattenR(self.P00),
1230
+ self.backend.bk_flattenR(self.S2),
1231
+ self.backend.bk_flattenR(self.S2),
1232
+ ],
1233
+ axis=0,
1234
+ )
861
1235
 
862
1236
  # ---------------------------------------------−---------
863
1237
  def flattenMask(self):
864
- if isinstance(self.S1,np.ndarray):
865
- tmp=np.expand_dims(np.concatenate([self.S1[0].flatten(),
866
- self.P00[0].flatten(),
867
- self.S2[0].flatten(),
868
- self.S2L[0].flatten()],0),0)
869
- for k in range(1,self.P00.shape[0]):
870
- tmp=np.concatenate([tmp,np.expand_dims(np.concatenate([self.S1[k].flatten(),
871
- self.P00[k].flatten(),
872
- self.S2[k].flatten(),
873
- self.S2L[k].flatten()],0),0)],0)
874
-
875
-
876
- return np.concatenate([tmp,np.expand_dims(self.S0,1)],1)
1238
+ if isinstance(self.S1, np.ndarray):
1239
+ tmp = np.expand_dims(
1240
+ np.concatenate(
1241
+ [
1242
+ self.S1[0].flatten(),
1243
+ self.P00[0].flatten(),
1244
+ self.S2[0].flatten(),
1245
+ self.S2L[0].flatten(),
1246
+ ],
1247
+ 0,
1248
+ ),
1249
+ 0,
1250
+ )
1251
+ for k in range(1, self.P00.shape[0]):
1252
+ tmp = np.concatenate(
1253
+ [
1254
+ tmp,
1255
+ np.expand_dims(
1256
+ np.concatenate(
1257
+ [
1258
+ self.S1[k].flatten(),
1259
+ self.P00[k].flatten(),
1260
+ self.S2[k].flatten(),
1261
+ self.S2L[k].flatten(),
1262
+ ],
1263
+ 0,
1264
+ ),
1265
+ 0,
1266
+ ),
1267
+ ],
1268
+ 0,
1269
+ )
1270
+
1271
+ return np.concatenate([tmp, np.expand_dims(self.S0, 1)], 1)
877
1272
  else:
878
- tmp=self.backend.bk_expand_dims(self.backend.bk_concat([self.backend.bk_flattenR(self.S1[0]),
879
- self.backend.bk_flattenR(self.P00[0]),
880
- self.backend.bk_flattenR(self.S2[0]),
881
- self.backend.bk_flattenR(self.S2[0])],axis=0),0)
882
- for k in range(1,self.P00.shape[0]):
883
- ltmp=self.backend.bk_expand_dims(self.backend.bk_concat([self.backend.bk_flattenR(self.S1[k]),
884
- self.backend.bk_flattenR(self.P00[k]),
885
- self.backend.bk_flattenR(self.S2[k]),
886
- self.backend.bk_flattenR(self.S2[k])],axis=0),0)
887
- tmp=self.backend.bk_concat([tmp,ltmp],0)
888
-
889
- return self.backend.bk_concat([tmp,self.backend.bk_expand_dims(self.S0,1)],1)
890
-
1273
+ tmp = self.backend.bk_expand_dims(
1274
+ self.backend.bk_concat(
1275
+ [
1276
+ self.backend.bk_flattenR(self.S1[0]),
1277
+ self.backend.bk_flattenR(self.P00[0]),
1278
+ self.backend.bk_flattenR(self.S2[0]),
1279
+ self.backend.bk_flattenR(self.S2[0]),
1280
+ ],
1281
+ axis=0,
1282
+ ),
1283
+ 0,
1284
+ )
1285
+ for k in range(1, self.P00.shape[0]):
1286
+ ltmp = self.backend.bk_expand_dims(
1287
+ self.backend.bk_concat(
1288
+ [
1289
+ self.backend.bk_flattenR(self.S1[k]),
1290
+ self.backend.bk_flattenR(self.P00[k]),
1291
+ self.backend.bk_flattenR(self.S2[k]),
1292
+ self.backend.bk_flattenR(self.S2[k]),
1293
+ ],
1294
+ axis=0,
1295
+ ),
1296
+ 0,
1297
+ )
1298
+ tmp = self.backend.bk_concat([tmp, ltmp], 0)
1299
+
1300
+ return self.backend.bk_concat(
1301
+ [tmp, self.backend.bk_expand_dims(self.S0, 1)], 1
1302
+ )
1303
+
891
1304
  # ---------------------------------------------−---------
892
- def model(self,i__y,add=0,dx=3,dell=2,weigth=None,inverse=False):
1305
+ def model(self, i__y, add=0, dx=3, dell=2, weigth=None, inverse=False):
893
1306
 
894
- if i__y.shape[0]<dx+1:
895
- l__dx=i__y.shape[0]-1
1307
+ if i__y.shape[0] < dx + 1:
1308
+ l__dx = i__y.shape[0] - 1
896
1309
  else:
897
- l__dx=dx
1310
+ l__dx = dx
898
1311
 
899
- if i__y.shape[0]<dell:
900
- l__dell=0
1312
+ if i__y.shape[0] < dell:
1313
+ l__dell = 0
901
1314
  else:
902
- l__dell=dell
1315
+ l__dell = dell
903
1316
 
904
- if l__dx<2:
905
- res=np.zeros([i__y.shape[0]+add])
1317
+ if l__dx < 2:
1318
+ res = np.zeros([i__y.shape[0] + add])
906
1319
  if inverse:
907
- res[:-add]=i__y
1320
+ res[:-add] = i__y
908
1321
  else:
909
- res[add:]=i__y[0:]
1322
+ res[add:] = i__y[0:]
910
1323
  return res
911
1324
 
912
1325
  if weigth is None:
913
- w=2**(np.arange(l__dx))
1326
+ w = 2 ** (np.arange(l__dx))
914
1327
  else:
915
1328
  if not inverse:
916
- w=weigth[0:l__dx]
1329
+ w = weigth[0:l__dx]
917
1330
  else:
918
- w=weigth[-l__dx:]
1331
+ w = weigth[-l__dx:]
919
1332
 
920
- x=np.arange(l__dx)+1
1333
+ x = np.arange(l__dx) + 1
921
1334
  if not inverse:
922
- y=np.log(i__y[1:l__dx+1])
1335
+ y = np.log(i__y[1 : l__dx + 1])
923
1336
  else:
924
- y=np.log(i__y[-(l__dx+1):-1])
1337
+ y = np.log(i__y[-(l__dx + 1) : -1])
925
1338
 
926
- r=np.polyfit(x,y,1,w=w)
1339
+ r = np.polyfit(x, y, 1, w=w)
927
1340
 
928
1341
  if inverse:
929
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-1)+r[1])
930
- res[:-(l__dell+add)]=i__y[:-l__dell]
1342
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - 1) + r[1])
1343
+ res[: -(l__dell + add)] = i__y[:-l__dell]
931
1344
  else:
932
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-add)+r[1])
933
- res[l__dell+add:]=i__y[l__dell:]
1345
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - add) + r[1])
1346
+ res[l__dell + add :] = i__y[l__dell:]
934
1347
  return res
935
1348
 
936
- def findn(self,n):
937
- d=np.sqrt(1+8*n)
938
- return int((d-1)/2)
1349
+ def findn(self, n):
1350
+ d = np.sqrt(1 + 8 * n)
1351
+ return int((d - 1) / 2)
939
1352
 
940
- def findidx(self,s2):
941
- i1=np.zeros([s2.shape[1]],dtype='int')
942
- i2=np.zeros([s2.shape[1]],dtype='int')
943
- n=0
1353
+ def findidx(self, s2):
1354
+ i1 = np.zeros([s2.shape[1]], dtype="int")
1355
+ i2 = np.zeros([s2.shape[1]], dtype="int")
1356
+ n = 0
944
1357
  for k in range(self.findn(s2.shape[1])):
945
- i1[n:n+k+1]=np.arange(k+1)
946
- i2[n:n+k+1]=k
947
- n=n+k+1
948
- return i1,i2
949
-
950
- def extrapol_s2(self,add,lnorm=1):
951
- if lnorm==1:
952
- s2=self.S2.numpy()
953
- if lnorm==2:
954
- s2=self.S2L.numpy()
955
- i1,i2=self.findidx(s2)
956
-
957
- so2=np.zeros([s2.shape[0],(self.findn(s2.shape[1])+add)*(self.findn(s2.shape[1])+add+1)//2,s2.shape[2],s2.shape[3]])
958
- oi1,oi2=self.findidx(so2)
959
- for l in range(s2.shape[0]):
1358
+ i1[n : n + k + 1] = np.arange(k + 1)
1359
+ i2[n : n + k + 1] = k
1360
+ n = n + k + 1
1361
+ return i1, i2
1362
+
1363
+ def extrapol_s2(self, add, lnorm=1):
1364
+ if lnorm == 1:
1365
+ s2 = self.S2.numpy()
1366
+ if lnorm == 2:
1367
+ s2 = self.S2L.numpy()
1368
+ i1, i2 = self.findidx(s2)
1369
+
1370
+ so2 = np.zeros(
1371
+ [
1372
+ s2.shape[0],
1373
+ (self.findn(s2.shape[1]) + add)
1374
+ * (self.findn(s2.shape[1]) + add + 1)
1375
+ // 2,
1376
+ s2.shape[2],
1377
+ s2.shape[3],
1378
+ ]
1379
+ )
1380
+ oi1, oi2 = self.findidx(so2)
1381
+ for l_batch in range(s2.shape[0]):
960
1382
  for k in range(self.findn(s2.shape[1])):
961
1383
  for i in range(s2.shape[2]):
962
1384
  for j in range(s2.shape[3]):
963
- tmp=self.model(s2[l,i2==k,i,j],dx=4,dell=1,add=add,weigth=np.array([1,2,2,2]))
964
- tmp[np.isnan(tmp)]=0.0
965
- so2[l,oi2==k+add,i,j]=tmp
966
-
967
-
968
- for l in range(s2.shape[0]):
969
- for k in range(add+1,-1,-1):
970
- lidx=np.where(oi2-oi1==k)[0]
971
- lidx2=np.where(oi2-oi1==k+1)[0]
1385
+ tmp = self.model(
1386
+ s2[l_batch, i2 == k, i, j],
1387
+ dx=4,
1388
+ dell=1,
1389
+ add=add,
1390
+ weigth=np.array([1, 2, 2, 2]),
1391
+ )
1392
+ tmp[np.isnan(tmp)] = 0.0
1393
+ so2[l_batch, oi2 == k + add, i, j] = tmp
1394
+
1395
+ for l_batch in range(s2.shape[0]):
1396
+ for k in range(add + 1, -1, -1):
1397
+ lidx = np.where(oi2 - oi1 == k)[0]
1398
+ lidx2 = np.where(oi2 - oi1 == k + 1)[0]
972
1399
  for i in range(s2.shape[2]):
973
1400
  for j in range(s2.shape[3]):
974
- so2[l,lidx[0:add+2-k],i,j]=so2[l,lidx2[0:add+2-k],i,j]
1401
+ so2[l_batch, lidx[0 : add + 2 - k], i, j] = so2[
1402
+ l_batch, lidx2[0 : add + 2 - k], i, j
1403
+ ]
975
1404
 
976
- return(so2)
1405
+ return so2
977
1406
 
978
- def extrapol_s1(self,i_s1,add):
979
- s1=i_s1.numpy()
980
- so1=np.zeros([s1.shape[0],s1.shape[1]+add,s1.shape[2]])
1407
+ def extrapol_s1(self, i_s1, add):
1408
+ s1 = i_s1.numpy()
1409
+ so1 = np.zeros([s1.shape[0], s1.shape[1] + add, s1.shape[2]])
981
1410
  for k in range(s1.shape[0]):
982
1411
  for i in range(s1.shape[2]):
983
- so1[k,:,i]=self.model(s1[k,:,i],dx=4,dell=1,add=add)
984
- so1[k,np.isnan(so1[k,:,i]),i]=0.0
1412
+ so1[k, :, i] = self.model(s1[k, :, i], dx=4, dell=1, add=add)
1413
+ so1[k, np.isnan(so1[k, :, i]), i] = 0.0
985
1414
  return so1
986
1415
 
987
- def extrapol(self,add):
988
- return scat(self.extrapol_s1(self.P00,add), \
989
- self.S0, \
990
- self.extrapol_s1(self.S1,add), \
991
- self.extrapol_s2(add,lnorm=1), \
992
- self.extrapol_s2(add,lnorm=2),self.j1,self.j2,backend=self.backend)
993
-
994
-
995
-
996
-
997
-
1416
+ def extrapol(self, add):
1417
+ return scat(
1418
+ self.extrapol_s1(self.P00, add),
1419
+ self.S0,
1420
+ self.extrapol_s1(self.S1, add),
1421
+ self.extrapol_s2(add, lnorm=1),
1422
+ self.extrapol_s2(add, lnorm=2),
1423
+ self.j1,
1424
+ self.j2,
1425
+ backend=self.backend,
1426
+ )
1427
+
1428
+
998
1429
  class funct(FOC.FoCUS):
999
-
1000
- def fill(self,im,nullval=hp.UNSEEN):
1001
- return self.fill_healpy(im,nullval=nullval)
1002
-
1003
- def moments(self,list_scat):
1004
- S0=None
1430
+
1431
+ def fill(self, im, nullval=hp.UNSEEN):
1432
+ return self.fill_healpy(im, nullval=nullval)
1433
+
1434
+ def moments(self, list_scat):
1435
+ S0 = None
1005
1436
  for k in list_scat:
1006
- tmp=list_scat[k]
1007
- nS0=np.expand_dims(tmp.S0.numpy(),0)
1008
- nP00=np.expand_dims(tmp.P00.numpy(),0)
1009
- nS1=np.expand_dims(tmp.S1.numpy(),0)
1010
- nS2=np.expand_dims(tmp.S2.numpy(),0)
1011
- nS2L=np.expand_dims(tmp.S2L.numpy(),0)
1012
-
1437
+ tmp = list_scat[k]
1438
+ nS0 = np.expand_dims(tmp.S0.numpy(), 0)
1439
+ nP00 = np.expand_dims(tmp.P00.numpy(), 0)
1440
+ nS1 = np.expand_dims(tmp.S1.numpy(), 0)
1441
+ nS2 = np.expand_dims(tmp.S2.numpy(), 0)
1442
+ nS2L = np.expand_dims(tmp.S2L.numpy(), 0)
1443
+
1013
1444
  if S0 is None:
1014
- S0=nS0
1015
- P00=nP00
1016
- S1=nS1
1017
- S2=nS2
1018
- S2L=nS2L
1445
+ S0 = nS0
1446
+ P00 = nP00
1447
+ S1 = nS1
1448
+ S2 = nS2
1449
+ S2L = nS2L
1019
1450
  else:
1020
- S0=np.concatenate([S0,nS0],0)
1021
- P00=np.concatenate([P00,nP00],0)
1022
- S1=np.concatenate([S1,nS1],0)
1023
- S2=np.concatenate([S2,nS2],0)
1024
- S2L=np.concatenate([S2L,nS2L],0)
1025
-
1026
- sS0=np.std(S0,0)
1027
- sP00=np.std(P00,0)
1028
- sS1=np.std(S1,0)
1029
- sS2=np.std(S2,0)
1030
- sS2L=np.std(S2L,0)
1031
-
1032
- mS0=np.mean(S0,0)
1033
- mP00=np.mean(P00,0)
1034
- mS1=np.mean(S1,0)
1035
- mS2=np.mean(S2,0)
1036
- mS2L=np.mean(S2L,0)
1037
-
1038
- return scat(mP00,mS0,mS1,mS2,mS2L,tmp.j1,tmp.j2,backend=self.backend), \
1039
- scat(sP00,sS0,sS1,sS2,sS2L,tmp.j1,tmp.j2,backend=self.backend)
1040
-
1041
- def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,calc_var=False,norm=None):
1451
+ S0 = np.concatenate([S0, nS0], 0)
1452
+ P00 = np.concatenate([P00, nP00], 0)
1453
+ S1 = np.concatenate([S1, nS1], 0)
1454
+ S2 = np.concatenate([S2, nS2], 0)
1455
+ S2L = np.concatenate([S2L, nS2L], 0)
1456
+
1457
+ sS0 = np.std(S0, 0)
1458
+ sP00 = np.std(P00, 0)
1459
+ sS1 = np.std(S1, 0)
1460
+ sS2 = np.std(S2, 0)
1461
+ sS2L = np.std(S2L, 0)
1462
+
1463
+ mS0 = np.mean(S0, 0)
1464
+ mP00 = np.mean(P00, 0)
1465
+ mS1 = np.mean(S1, 0)
1466
+ mS2 = np.mean(S2, 0)
1467
+ mS2L = np.mean(S2L, 0)
1468
+
1469
+ return scat(
1470
+ mP00, mS0, mS1, mS2, mS2L, tmp.j1, tmp.j2, backend=self.backend
1471
+ ), scat(sP00, sS0, sS1, sS2, sS2L, tmp.j1, tmp.j2, backend=self.backend)
1472
+
1473
+ def eval(
1474
+ self,
1475
+ image1,
1476
+ image2=None,
1477
+ mask=None,
1478
+ Auto=True,
1479
+ s0_off=1e-6,
1480
+ calc_var=False,
1481
+ norm=None,
1482
+ ):
1042
1483
  # Check input consistency
1043
1484
  if image2 is not None:
1044
- if list(image1.shape)!=list(image2.shape):
1045
- print('The two input image should have the same size to eval Scattering')
1046
-
1485
+ if list(image1.shape) != list(image2.shape):
1486
+ print(
1487
+ "The two input image should have the same size to eval Scattering"
1488
+ )
1489
+
1047
1490
  return None
1048
1491
  if mask is not None:
1049
- if list(image1.shape)!=list(mask.shape)[1:]:
1050
- print('The mask should have the same size than the input image to eval Scattering')
1051
- print('Image shape ',image1.shape,'Mask shape ',mask.shape)
1492
+ if list(image1.shape) != list(mask.shape)[1:]:
1493
+ print(
1494
+ "The mask should have the same size than the input image to eval Scattering"
1495
+ )
1496
+ print("Image shape ", image1.shape, "Mask shape ", mask.shape)
1052
1497
  return None
1053
- if self.use_2D and len(image1.shape)<2:
1054
- print('To work with 2D scattering transform, two dimension is needed, input map has only on dimension')
1498
+ if self.use_2D and len(image1.shape) < 2:
1499
+ print(
1500
+ "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
1501
+ )
1055
1502
  return None
1056
-
1057
-
1503
+
1058
1504
  ### AUTO OR CROSS
1059
1505
  cross = False
1060
1506
  if image2 is not None:
1061
1507
  cross = True
1062
- all_cross=not Auto
1063
- else:
1064
- all_cross=False
1065
-
1508
+
1066
1509
  # Check if image1 is [Npix] or [Nbatch,Npix]
1067
- axis=1
1068
-
1510
+ axis = 1
1511
+
1069
1512
  # determine jmax and nside corresponding to the input map
1070
1513
  im_shape = image1.shape
1071
1514
  if self.use_2D:
1072
- if len(image1.shape)==2:
1073
- nside=np.min([im_shape[0],im_shape[1]])
1074
- npix = im_shape[0]*im_shape[1] # Number of pixels
1075
- x1=im_shape[0]
1076
- x2=im_shape[1]
1515
+ if len(image1.shape) == 2:
1516
+ nside = np.min([im_shape[0], im_shape[1]])
1517
+ npix = im_shape[0] * im_shape[1] # Number of pixels
1077
1518
  else:
1078
- nside=np.min([im_shape[1],im_shape[2]])
1079
- npix = im_shape[1]*im_shape[2] # Number of pixels
1080
- x1=im_shape[1]
1081
- x2=im_shape[2]
1082
- jmax = int(np.log(nside-self.KERNELSZ) / np.log(2)) # Number of j scales
1519
+ nside = np.min([im_shape[1], im_shape[2]])
1520
+ npix = im_shape[1] * im_shape[2] # Number of pixels
1521
+ jmax = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
1083
1522
  else:
1084
- if len(image1.shape)==2:
1523
+ if len(image1.shape) == 2:
1085
1524
  npix = int(im_shape[1]) # Number of pixels
1086
1525
  else:
1087
1526
  npix = int(im_shape[0]) # Number of pixels
1088
1527
 
1089
- nside=int(np.sqrt(npix//12))
1090
-
1091
- jmax=int(np.log(nside)/np.log(2)) #-self.OSTEP
1528
+ nside = int(np.sqrt(npix // 12))
1529
+
1530
+ jmax = int(np.log(nside) / np.log(2)) # -self.OSTEP
1092
1531
 
1093
1532
  ### LOCAL VARIABLES (IMAGES and MASK)
1094
1533
  # Check if image1 is [Npix] or [Nbatch,Npix]
1095
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1534
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1096
1535
  # image1 is [Nbatch, Npix]
1097
- I1 = self.backend.bk_cast(self.backend.bk_expand_dims(image1,0)) # Local image1 [Nbatch, Npix]
1536
+ I1 = self.backend.bk_cast(
1537
+ self.backend.bk_expand_dims(image1, 0)
1538
+ ) # Local image1 [Nbatch, Npix]
1098
1539
  if cross:
1099
- I2 = self.backend.bk_cast(self.backend.bk_expand_dims(image2,0)) # Local image2 [Nbatch, Npix]
1540
+ I2 = self.backend.bk_cast(
1541
+ self.backend.bk_expand_dims(image2, 0)
1542
+ ) # Local image2 [Nbatch, Npix]
1100
1543
  else:
1101
- I1=self.backend.bk_cast(image1)
1544
+ I1 = self.backend.bk_cast(image1)
1102
1545
  if cross:
1103
- I2=self.backend.bk_cast(image2)
1104
-
1546
+ I2 = self.backend.bk_cast(image2)
1547
+
1105
1548
  # self.mask is [Nmask, Npix]
1106
1549
  if mask is None:
1107
1550
  if self.use_2D:
1108
- vmask = self.backend.bk_ones([1, I1.shape[axis], I1.shape[axis+1]],dtype=self.all_type)
1551
+ vmask = self.backend.bk_ones(
1552
+ [1, I1.shape[axis], I1.shape[axis + 1]], dtype=self.all_type
1553
+ )
1109
1554
  else:
1110
1555
  vmask = self.backend.bk_ones([1, I1.shape[axis]], dtype=self.all_type)
1111
1556
  else:
1112
1557
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
1113
1558
 
1114
- if self.KERNELSZ>3:
1115
- if self.KERNELSZ==5:
1559
+ if self.KERNELSZ > 3:
1560
+ if self.KERNELSZ == 5:
1116
1561
  # if the kernel size is bigger than 3 increase the binning before smoothing
1117
1562
  if self.use_2D:
1118
- l_image1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1119
- vmask=self.up_grade(vmask,I1.shape[axis]*2,axis=1,nouty=I1.shape[axis+1]*2)
1563
+ l_image1 = self.up_grade(
1564
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
1565
+ )
1566
+ vmask = self.up_grade(
1567
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
1568
+ )
1120
1569
  else:
1121
- l_image1=self.up_grade(I1,nside*2,axis=axis)
1122
- vmask=self.up_grade(vmask,nside*2,axis=1)
1123
-
1570
+ l_image1 = self.up_grade(I1, nside * 2, axis=axis)
1571
+ vmask = self.up_grade(vmask, nside * 2, axis=1)
1572
+
1124
1573
  if cross:
1125
1574
  if self.use_2D:
1126
- l_image2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1575
+ l_image2 = self.up_grade(
1576
+ I2,
1577
+ I2.shape[axis] * 2,
1578
+ axis=axis,
1579
+ nouty=I2.shape[axis + 1] * 2,
1580
+ )
1127
1581
  else:
1128
- l_image2=self.up_grade(I2,nside*2,axis=axis)
1582
+ l_image2 = self.up_grade(I2, nside * 2, axis=axis)
1129
1583
  else:
1130
1584
  # if the kernel size is bigger than 3 increase the binning before smoothing
1131
1585
  if self.use_2D:
1132
- l_image1=self.up_grade(l_image1,I1.shape[axis]*4,axis=axis,nouty=I1.shape[axis+1]*4)
1133
- vmask=self.up_grade(vmask,I1.shape[axis]*4,axis=1,nouty=I1.shape[axis+1]*4)
1586
+ l_image1 = self.up_grade(
1587
+ l_image1,
1588
+ I1.shape[axis] * 4,
1589
+ axis=axis,
1590
+ nouty=I1.shape[axis + 1] * 4,
1591
+ )
1592
+ vmask = self.up_grade(
1593
+ vmask, I1.shape[axis] * 4, axis=1, nouty=I1.shape[axis + 1] * 4
1594
+ )
1134
1595
  else:
1135
- l_image1=self.up_grade(l_image1,nside*4,axis=axis)
1136
- vmask=self.up_grade(vmask,nside*4,axis=1)
1137
-
1596
+ l_image1 = self.up_grade(l_image1, nside * 4, axis=axis)
1597
+ vmask = self.up_grade(vmask, nside * 4, axis=1)
1598
+
1138
1599
  if cross:
1139
1600
  if self.use_2D:
1140
- l_image2=self.up_grade(l_image2,I2.shape[axis]*4,axis=axis,nouty=I2.shape[axis+1]*4)
1601
+ l_image2 = self.up_grade(
1602
+ l_image2,
1603
+ I2.shape[axis] * 4,
1604
+ axis=axis,
1605
+ nouty=I2.shape[axis + 1] * 4,
1606
+ )
1141
1607
  else:
1142
- l_image2=self.up_grade(l_image2,nside*4,axis=axis)
1608
+ l_image2 = self.up_grade(l_image2, nside * 4, axis=axis)
1143
1609
  else:
1144
- l_image1=I1
1610
+ l_image1 = I1
1145
1611
  if cross:
1146
- l_image2=I2
1612
+ l_image2 = I2
1147
1613
 
1148
1614
  if calc_var:
1149
- s0,vs0 = self.masked_mean(l_image1,vmask,axis=axis,calc_var=True)
1150
- s0=s0+s0_off
1615
+ s0, vs0 = self.masked_mean(l_image1, vmask, axis=axis, calc_var=True)
1616
+ s0 = s0 + s0_off
1151
1617
  else:
1152
- s0 = self.masked_mean(l_image1,vmask,axis=axis)+s0_off
1153
-
1154
- if cross and Auto==False:
1618
+ s0 = self.masked_mean(l_image1, vmask, axis=axis) + s0_off
1619
+
1620
+ if cross and not Auto:
1155
1621
  if calc_var:
1156
- s02,vs02=self.masked_mean(l_image2,vmask,axis=axis,calc_var=True)
1622
+ s02, vs02 = self.masked_mean(l_image2, vmask, axis=axis, calc_var=True)
1157
1623
  else:
1158
- s02=self.masked_mean(l_image2,vmask,axis=axis)
1159
-
1160
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1624
+ s02 = self.masked_mean(l_image2, vmask, axis=axis)
1625
+
1626
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1161
1627
  if self.backend.bk_is_complex(s0):
1162
- s0 = self.backend.bk_complex(s0,s02+s0_off)
1628
+ s0 = self.backend.bk_complex(s0, s02 + s0_off)
1163
1629
  if calc_var:
1164
- vs0 = self.backend.bk_complex(vs0,vs02)
1630
+ vs0 = self.backend.bk_complex(vs0, vs02)
1165
1631
  else:
1166
- s0 = self.backend.bk_concat([s0,s02],axis=0)
1632
+ s0 = self.backend.bk_concat([s0, s02], axis=0)
1167
1633
  if calc_var:
1168
- vs0 = self.backend.bk_concat([vs0,vs02],axis=0)
1634
+ vs0 = self.backend.bk_concat([vs0, vs02], axis=0)
1169
1635
  else:
1170
1636
  if self.backend.bk_is_complex(s0):
1171
- s0 = self.backend.bk_complex(s0,s02+s0_off)
1637
+ s0 = self.backend.bk_complex(s0, s02 + s0_off)
1172
1638
  if calc_var:
1173
- vs0 = self.backend.bk_complex(vs0,vs02)
1639
+ vs0 = self.backend.bk_complex(vs0, vs02)
1174
1640
  else:
1175
- s0 = self.backend.bk_concat([s0,s02],axis=0)
1641
+ s0 = self.backend.bk_concat([s0, s02], axis=0)
1176
1642
  if calc_var:
1177
- vs0 = self.backend.bk_concat([vs0,vs02],axis=0)
1178
-
1179
- s1=None
1180
- s2=None
1181
- s2l=None
1182
- p00=None
1183
- s2j1=None
1184
- s2j2=None
1185
-
1186
- l2_image=None
1187
- l2_image_imag=None
1188
-
1643
+ vs0 = self.backend.bk_concat([vs0, vs02], axis=0)
1644
+
1645
+ s1 = None
1646
+ s2 = None
1647
+ s2l = None
1648
+ p00 = None
1649
+ s2j1 = None
1650
+ s2j2 = None
1651
+ l2_image = None
1189
1652
  for j1 in range(jmax):
1190
- if j1<jmax-self.OSTEP: # stop to add scales
1653
+ if j1 < jmax - self.OSTEP: # stop to add scales
1191
1654
  # Convol image along the axis defined by 'axis' using the wavelet defined at
1192
1655
  # the foscat initialisation
1193
- #c_image_real is [....,Npix_j1,....,Norient]
1194
- c_image1=self.convol(l_image1,axis=axis)
1656
+ # c_image_real is [....,Npix_j1,....,Norient]
1657
+ c_image1 = self.convol(l_image1, axis=axis)
1195
1658
  if cross:
1196
- c_image2=self.convol(l_image2,axis=axis)
1659
+ c_image2 = self.convol(l_image2, axis=axis)
1197
1660
  else:
1198
- c_image2=c_image1
1661
+ c_image2 = c_image1
1199
1662
 
1200
1663
  # Compute (a+ib)*(a+ib)* the last c_image column is the real and imaginary part
1201
- conj=c_image1*self.backend.bk_conjugate(c_image2)
1202
-
1664
+ conj = c_image1 * self.backend.bk_conjugate(c_image2)
1665
+
1203
1666
  if Auto:
1204
- conj=self.backend.bk_real(conj)
1667
+ conj = self.backend.bk_real(conj)
1205
1668
 
1206
1669
  # Compute l_p00 [....,....,Nmask,j1,Norient]
1207
1670
  if calc_var:
1208
- l_p00,l_vp00 = self.masked_mean(conj,vmask,axis=axis,rank=j1,calc_var=True)
1209
- l_p00 = self.backend.bk_expand_dims(l_p00,-2)
1210
- l_vp00 = self.backend.bk_expand_dims(l_vp00,-2)
1671
+ l_p00, l_vp00 = self.masked_mean(
1672
+ conj, vmask, axis=axis, rank=j1, calc_var=True
1673
+ )
1674
+ l_p00 = self.backend.bk_expand_dims(l_p00, -2)
1675
+ l_vp00 = self.backend.bk_expand_dims(l_vp00, -2)
1211
1676
  else:
1212
- l_p00 = self.masked_mean(conj,vmask,axis=axis,rank=j1)
1213
- l_p00 = self.backend.bk_expand_dims(l_p00,-2)
1677
+ l_p00 = self.masked_mean(conj, vmask, axis=axis, rank=j1)
1678
+ l_p00 = self.backend.bk_expand_dims(l_p00, -2)
1214
1679
 
1215
- conj=self.backend.bk_L1(conj)
1680
+ conj = self.backend.bk_L1(conj)
1216
1681
 
1217
- # Compute l_s1 [....,....,Nmask,1,Norient]
1682
+ # Compute l_s1 [....,....,Nmask,1,Norient]
1218
1683
  if calc_var:
1219
- l_s1,l_vs1 = self.masked_mean(conj,vmask,axis=axis,rank=j1,calc_var=True)
1220
- l_s1 =self.backend.bk_expand_dims(l_s1,-2)
1221
- l_vs1 =self.backend.bk_expand_dims(l_vs1,-2)
1684
+ l_s1, l_vs1 = self.masked_mean(
1685
+ conj, vmask, axis=axis, rank=j1, calc_var=True
1686
+ )
1687
+ l_s1 = self.backend.bk_expand_dims(l_s1, -2)
1688
+ l_vs1 = self.backend.bk_expand_dims(l_vs1, -2)
1222
1689
  else:
1223
- l_s1 = self.backend.bk_expand_dims(self.masked_mean(conj,vmask,axis=axis,rank=j1),-2)
1690
+ l_s1 = self.backend.bk_expand_dims(
1691
+ self.masked_mean(conj, vmask, axis=axis, rank=j1), -2
1692
+ )
1224
1693
 
1225
- # Concat S1,P00 [....,....,Nmask,j1,Norient]
1694
+ # Concat S1,P00 [....,....,Nmask,j1,Norient]
1226
1695
  if s1 is None:
1227
- s1=l_s1
1228
- p00=l_p00
1696
+ s1 = l_s1
1697
+ p00 = l_p00
1229
1698
  if calc_var:
1230
- vs1=l_vs1
1231
- vp00=l_vp00
1699
+ vs1 = l_vs1
1700
+ vp00 = l_vp00
1232
1701
  else:
1233
- s1=self.backend.bk_concat([s1,l_s1],axis=-2)
1234
- p00=self.backend.bk_concat([p00,l_p00],axis=-2)
1702
+ s1 = self.backend.bk_concat([s1, l_s1], axis=-2)
1703
+ p00 = self.backend.bk_concat([p00, l_p00], axis=-2)
1235
1704
  if calc_var:
1236
- vs1=self.backend.bk_concat([vs1,l_vs1],axis=-2)
1237
- vp00=self.backend.bk_concat([vp00,l_vp00],axis=-2)
1705
+ vs1 = self.backend.bk_concat([vs1, l_vs1], axis=-2)
1706
+ vp00 = self.backend.bk_concat([vp00, l_vp00], axis=-2)
1238
1707
 
1239
1708
  # Concat l2_image [....,j1,Npix_j1,,....,Norient]
1240
1709
  if l2_image is None:
1241
1710
  if self.use_2D:
1242
- l2_image=self.backend.bk_expand_dims(conj,axis=-4)
1711
+ l2_image = self.backend.bk_expand_dims(conj, axis=-4)
1243
1712
  else:
1244
- l2_image=self.backend.bk_expand_dims(conj,axis=-3)
1713
+ l2_image = self.backend.bk_expand_dims(conj, axis=-3)
1245
1714
  else:
1246
1715
  if self.use_2D:
1247
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=-4),l2_image],axis=-4)
1716
+ l2_image = self.backend.bk_concat(
1717
+ [self.backend.bk_expand_dims(conj, axis=-4), l2_image],
1718
+ axis=-4,
1719
+ )
1248
1720
  else:
1249
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=-3),l2_image],axis=-3)
1721
+ l2_image = self.backend.bk_concat(
1722
+ [self.backend.bk_expand_dims(conj, axis=-3), l2_image],
1723
+ axis=-3,
1724
+ )
1250
1725
 
1251
1726
  # Convol l2_image [....,Npix_j1,j1,....,Norient,Norient]
1252
- c2_image=self.convol(self.backend.bk_relu(l2_image),axis=axis+1)
1727
+ c2_image = self.convol(self.backend.bk_relu(l2_image), axis=axis + 1)
1253
1728
 
1254
- conj2p=c2_image*self.backend.bk_conjugate(c2_image)
1255
- conj2pl1=self.backend.bk_L1(conj2p)
1729
+ conj2p = c2_image * self.backend.bk_conjugate(c2_image)
1730
+ conj2pl1 = self.backend.bk_L1(conj2p)
1256
1731
 
1257
1732
  if Auto:
1258
- conj2p=self.backend.bk_real(conj2p)
1259
- conj2pl1=self.backend.bk_real(conj2pl1)
1733
+ conj2p = self.backend.bk_real(conj2p)
1734
+ conj2pl1 = self.backend.bk_real(conj2pl1)
1260
1735
 
1261
- c2_image=self.convol(self.backend.bk_relu(-l2_image),axis=axis+1)
1736
+ c2_image = self.convol(self.backend.bk_relu(-l2_image), axis=axis + 1)
1262
1737
 
1263
- conj2m=c2_image*self.backend.bk_conjugate(c2_image)
1264
- conj2ml1=self.backend.bk_L1(conj2m)
1738
+ conj2m = c2_image * self.backend.bk_conjugate(c2_image)
1739
+ conj2ml1 = self.backend.bk_L1(conj2m)
1265
1740
 
1266
1741
  if Auto:
1267
- conj2m=self.backend.bk_real(conj2m)
1268
- conj2ml1=self.backend.bk_real(conj2ml1)
1269
-
1742
+ conj2m = self.backend.bk_real(conj2m)
1743
+ conj2ml1 = self.backend.bk_real(conj2ml1)
1744
+
1270
1745
  # Convol l_s2 [....,....,Nmask,j1,Norient,Norient]
1271
1746
  if calc_var:
1272
- l_s2,l_vs2 = self.masked_mean(conj2p-conj2m,vmask,axis=axis+1,rank=j1,calc_var=True)
1273
- l_s2l1,l_vs2l1 = self.masked_mean(conj2pl1-conj2ml1,vmask,axis=axis+1,rank=j1,calc_var=True)
1747
+ l_s2, l_vs2 = self.masked_mean(
1748
+ conj2p - conj2m, vmask, axis=axis + 1, rank=j1, calc_var=True
1749
+ )
1750
+ l_s2l1, l_vs2l1 = self.masked_mean(
1751
+ conj2pl1 - conj2ml1, vmask, axis=axis + 1, rank=j1, calc_var=True
1752
+ )
1274
1753
  else:
1275
- l_s2 = self.masked_mean(conj2p-conj2m,vmask,axis=axis+1,rank=j1)
1276
- l_s2l1 = self.masked_mean(conj2pl1-conj2ml1,vmask,axis=axis+1,rank=j1)
1754
+ l_s2 = self.masked_mean(conj2p - conj2m, vmask, axis=axis + 1, rank=j1)
1755
+ l_s2l1 = self.masked_mean(
1756
+ conj2pl1 - conj2ml1, vmask, axis=axis + 1, rank=j1
1757
+ )
1277
1758
 
1278
1759
  # Concat l_s2 [....,....,Nmask,j1*(j1+1)/2,Norient,Norient]
1279
1760
  if s2 is None:
1280
- s2l=l_s2
1281
- s2=l_s2l1
1761
+ s2l = l_s2
1762
+ s2 = l_s2l1
1282
1763
  if calc_var:
1283
- vs2l=l_vs2
1284
- vs2=l_vs2l1
1285
-
1286
- s2j1=np.arange(l_s2.shape[axis+1],dtype='int')
1287
- s2j2=j1*np.ones(l_s2.shape[axis+1],dtype='int')
1764
+ vs2l = l_vs2
1765
+ vs2 = l_vs2l1
1766
+
1767
+ s2j1 = np.arange(l_s2.shape[axis + 1], dtype="int")
1768
+ s2j2 = j1 * np.ones(l_s2.shape[axis + 1], dtype="int")
1288
1769
  else:
1289
- s2=self.backend.bk_concat([s2,l_s2l1],axis=-3)
1290
- s2l=self.backend.bk_concat([s2l,l_s2],axis=-3)
1770
+ s2 = self.backend.bk_concat([s2, l_s2l1], axis=-3)
1771
+ s2l = self.backend.bk_concat([s2l, l_s2], axis=-3)
1291
1772
  if calc_var:
1292
- vs2=self.backend.bk_concat([vs2,l_vs2l1],axis=-3)
1293
- vs2l=self.backend.bk_concat([vs2l,l_vs2],axis=-3)
1294
-
1295
- s2j1=np.concatenate([s2j1,np.arange(l_s2.shape[axis+1],dtype='int')],0)
1296
- s2j2=np.concatenate([s2j2,j1*np.ones(l_s2.shape[axis+1],dtype='int')],0)
1297
-
1298
- if j1!=jmax-1:
1299
- # Rescale vmask [Nmask,Npix_j1//4]
1300
- vmask = self.smooth(vmask,axis=1)
1301
- vmask = self.ud_grade_2(vmask,axis=1)
1773
+ vs2 = self.backend.bk_concat([vs2, l_vs2l1], axis=-3)
1774
+ vs2l = self.backend.bk_concat([vs2l, l_vs2], axis=-3)
1775
+
1776
+ s2j1 = np.concatenate(
1777
+ [s2j1, np.arange(l_s2.shape[axis + 1], dtype="int")], 0
1778
+ )
1779
+ s2j2 = np.concatenate(
1780
+ [s2j2, j1 * np.ones(l_s2.shape[axis + 1], dtype="int")], 0
1781
+ )
1782
+
1783
+ if j1 != jmax - 1:
1784
+ # Rescale vmask [Nmask,Npix_j1//4]
1785
+ vmask = self.smooth(vmask, axis=1)
1786
+ vmask = self.ud_grade_2(vmask, axis=1)
1302
1787
  if self.mask_thres is not None:
1303
- vmask = self.backend.bk_threshold(vmask,self.mask_thres)
1788
+ vmask = self.backend.bk_threshold(vmask, self.mask_thres)
1304
1789
 
1305
- # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1306
- l2_image = self.smooth(l2_image,axis=axis+1)
1307
- l2_image = self.ud_grade_2(l2_image,axis=axis+1)
1790
+ # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1791
+ l2_image = self.smooth(l2_image, axis=axis + 1)
1792
+ l2_image = self.ud_grade_2(l2_image, axis=axis + 1)
1308
1793
 
1309
- # Rescale l_image [....,Npix_j1//4,....]
1310
- l_image1 = self.smooth(l_image1,axis=axis)
1311
- l_image1 = self.ud_grade_2(l_image1,axis=axis)
1794
+ # Rescale l_image [....,Npix_j1//4,....]
1795
+ l_image1 = self.smooth(l_image1, axis=axis)
1796
+ l_image1 = self.ud_grade_2(l_image1, axis=axis)
1312
1797
  if cross:
1313
- l_image2 = self.smooth(l_image2,axis=axis)
1314
- l_image2 = self.ud_grade_2(l_image2,axis=axis)
1315
-
1316
-
1317
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1318
- sc_ret=scat(p00[0],s0[0],s1[0],s2[0],s2l[0],s2j1,s2j2,cross=cross,backend=self.backend)
1798
+ l_image2 = self.smooth(l_image2, axis=axis)
1799
+ l_image2 = self.ud_grade_2(l_image2, axis=axis)
1800
+
1801
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1802
+ sc_ret = scat(
1803
+ p00[0],
1804
+ s0[0],
1805
+ s1[0],
1806
+ s2[0],
1807
+ s2l[0],
1808
+ s2j1,
1809
+ s2j2,
1810
+ cross=cross,
1811
+ backend=self.backend,
1812
+ )
1319
1813
  else:
1320
- sc_ret=scat(p00,s0,s1,s2,s2l,s2j1,s2j2,cross=cross,backend=self.backend)
1321
-
1814
+ sc_ret = scat(
1815
+ p00, s0, s1, s2, s2l, s2j1, s2j2, cross=cross, backend=self.backend
1816
+ )
1817
+
1322
1818
  if calc_var:
1323
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1324
- vsc_ret=scat(vp00[0],vs0[0],vs1[0],vs2[0],vs2l[0],s2j1,s2j2,cross=cross,backend=self.backend)
1819
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1820
+ vsc_ret = scat(
1821
+ vp00[0],
1822
+ vs0[0],
1823
+ vs1[0],
1824
+ vs2[0],
1825
+ vs2l[0],
1826
+ s2j1,
1827
+ s2j2,
1828
+ cross=cross,
1829
+ backend=self.backend,
1830
+ )
1325
1831
  else:
1326
- vsc_ret=scat(vp00,vs0,vs1,vs2,vs2l,s2j1,s2j2,cross=cross,backend=self.backend)
1327
- return sc_ret,vsc_ret
1832
+ vsc_ret = scat(
1833
+ vp00,
1834
+ vs0,
1835
+ vs1,
1836
+ vs2,
1837
+ vs2l,
1838
+ s2j1,
1839
+ s2j2,
1840
+ cross=cross,
1841
+ backend=self.backend,
1842
+ )
1843
+ return sc_ret, vsc_ret
1328
1844
  else:
1329
1845
  return sc_ret
1330
1846
 
1331
- def square(self,x):
1847
+ def square(self, x):
1332
1848
  # the abs make the complex value usable for reduce_sum or mean
1333
- return scat(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1334
- self.backend.bk_square(self.backend.bk_abs(x.S0)),
1335
- self.backend.bk_square(self.backend.bk_abs(x.S1)),
1336
- self.backend.bk_square(self.backend.bk_abs(x.S2)),
1337
- self.backend.bk_square(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1338
-
1339
- def sqrt(self,x):
1849
+ return scat(
1850
+ self.backend.bk_square(self.backend.bk_abs(x.P00)),
1851
+ self.backend.bk_square(self.backend.bk_abs(x.S0)),
1852
+ self.backend.bk_square(self.backend.bk_abs(x.S1)),
1853
+ self.backend.bk_square(self.backend.bk_abs(x.S2)),
1854
+ self.backend.bk_square(self.backend.bk_abs(x.S2L)),
1855
+ x.j1,
1856
+ x.j2,
1857
+ backend=self.backend,
1858
+ )
1859
+
1860
+ def sqrt(self, x):
1340
1861
  # the abs make the complex value usable for reduce_sum or mean
1341
- return scat(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1342
- self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1343
- self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1344
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1345
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1346
-
1347
- def reduce_distance(self, x,y, sigma=None):
1348
-
1862
+ return scat(
1863
+ self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1864
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1865
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1866
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1867
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),
1868
+ x.j1,
1869
+ x.j2,
1870
+ backend=self.backend,
1871
+ )
1872
+
1873
+ def reduce_distance(self, x, y, sigma=None):
1874
+
1349
1875
  if isinstance(x, scat):
1350
1876
  if sigma is None:
1351
- result=self.diff_data(y.S0,x.S0,is_complex=False)
1352
- result+=self.diff_data(y.S1,x.S1)
1353
- result+=self.diff_data(y.P00,x.P00)
1354
- result+=self.diff_data(y.S2,x.S2)
1355
- result+=self.diff_data(y.S2L,x.S2L)
1877
+ result = self.diff_data(y.S0, x.S0, is_complex=False)
1878
+ result += self.diff_data(y.S1, x.S1)
1879
+ result += self.diff_data(y.P00, x.P00)
1880
+ result += self.diff_data(y.S2, x.S2)
1881
+ result += self.diff_data(y.S2L, x.S2L)
1356
1882
  else:
1357
- result=self.diff_data(y.S0,x.S0,is_complex=False,sigma=sigma.S0)
1358
- result+=self.diff_data(y.S1,x.S1,sigma=sigma.S1)
1359
- result+=self.diff_data(y.P00,x.P00,sigma=sigma.P00)
1360
- result+=self.diff_data(y.S2,x.S2,sigma=sigma.S2)
1361
- result+=self.diff_data(y.S2L,x.S2L,sigma=sigma.S2L)
1362
-
1363
- nval=self.backend.bk_size(x.S0)+self.backend.bk_size(x.P00)+ \
1364
- self.backend.bk_size(x.S1)+self.backend.bk_size(x.S2)+self.backend.bk_size(x.S2L)
1365
-
1366
- result/=self.backend.bk_cast(nval)
1883
+ result = self.diff_data(y.S0, x.S0, is_complex=False, sigma=sigma.S0)
1884
+ result += self.diff_data(y.S1, x.S1, sigma=sigma.S1)
1885
+ result += self.diff_data(y.P00, x.P00, sigma=sigma.P00)
1886
+ result += self.diff_data(y.S2, x.S2, sigma=sigma.S2)
1887
+ result += self.diff_data(y.S2L, x.S2L, sigma=sigma.S2L)
1888
+
1889
+ nval = (
1890
+ self.backend.bk_size(x.S0)
1891
+ + self.backend.bk_size(x.P00)
1892
+ + self.backend.bk_size(x.S1)
1893
+ + self.backend.bk_size(x.S2)
1894
+ + self.backend.bk_size(x.S2L)
1895
+ )
1896
+
1897
+ result /= self.backend.bk_cast(nval)
1367
1898
  else:
1368
1899
  return self.backend.bk_reduce_sum(x)
1369
1900
  return result
1370
1901
 
1371
- def reduce_mean(self,x,axis=None):
1902
+ def reduce_mean(self, x, axis=None):
1372
1903
  if axis is None:
1373
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))+ \
1374
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))+ \
1375
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))+ \
1376
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))+ \
1377
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1378
-
1379
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1380
- np.array(list(x.S0.shape)).prod()+ \
1381
- np.array(list(x.S1.shape)).prod()+ \
1382
- np.array(list(x.S2.shape)).prod()
1383
-
1384
- return tmp/ntmp
1904
+ tmp = (
1905
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))
1906
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))
1907
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))
1908
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))
1909
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1910
+ )
1911
+
1912
+ ntmp = (
1913
+ np.array(list(x.P00.shape)).prod()
1914
+ + np.array(list(x.S0.shape)).prod()
1915
+ + np.array(list(x.S1.shape)).prod()
1916
+ + np.array(list(x.S2.shape)).prod()
1917
+ )
1918
+
1919
+ return tmp / ntmp
1385
1920
  else:
1386
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00,axis=axis))+ \
1387
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0,axis=axis))+ \
1388
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1,axis=axis))+ \
1389
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2,axis=axis))+ \
1390
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L,axis=axis))
1391
-
1392
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1393
- np.array(list(x.S0.shape)).prod()+ \
1394
- np.array(list(x.S1.shape)).prod()+ \
1395
- np.array(list(x.S2.shape)).prod()+ \
1396
- np.array(list(x.S2L.shape)).prod()
1397
-
1398
- return tmp/ntmp
1399
-
1400
- def reduce_sum(self,x,axis=None):
1921
+ tmp = (
1922
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00, axis=axis))
1923
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0, axis=axis))
1924
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1, axis=axis))
1925
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2, axis=axis))
1926
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L, axis=axis))
1927
+ )
1928
+
1929
+ ntmp = (
1930
+ np.array(list(x.P00.shape)).prod()
1931
+ + np.array(list(x.S0.shape)).prod()
1932
+ + np.array(list(x.S1.shape)).prod()
1933
+ + np.array(list(x.S2.shape)).prod()
1934
+ + np.array(list(x.S2L.shape)).prod()
1935
+ )
1936
+
1937
+ return tmp / ntmp
1938
+
1939
+ def reduce_sum(self, x, axis=None):
1401
1940
  if axis is None:
1402
- return self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))+ \
1403
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))+ \
1404
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))+ \
1405
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))+ \
1406
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1941
+ return (
1942
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))
1943
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
1944
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
1945
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
1946
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1947
+ )
1407
1948
  else:
1408
- return scat(self.backend.bk_reduce_sum(x.P00,axis=axis),
1409
- self.backend.bk_reduce_sum(x.S0,axis=axis),
1410
- self.backend.bk_reduce_sum(x.S1,axis=axis),
1411
- self.backend.bk_reduce_sum(x.S2,axis=axis),
1412
- self.backend.bk_reduce_sum(x.S2L,axis=axis),x.j1,x.j2,backend=self.backend)
1413
-
1414
- def ldiff(self,sig,x):
1415
- return scat(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1416
- x.domult(sig.S0,x.S0)*x.domult(sig.S0,x.S0),
1417
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1418
- x.domult(sig.S2,x.S2)*x.domult(sig.S2,x.S2),
1419
- x.domult(sig.S2L,x.S2L)*x.domult(sig.S2L,x.S2L),x.j1,x.j2,backend=self.backend)
1420
-
1421
- def log(self,x):
1422
- return scat(self.backend.bk_log(x.P00),
1423
- self.backend.bk_log(x.S0),
1424
- self.backend.bk_log(x.S1),
1425
- self.backend.bk_log(x.S2),
1426
- self.backend.bk_log(x.S2L),x.j1,x.j2,backend=self.backend)
1427
- def abs(self,x):
1428
- return scat(self.backend.bk_abs(x.P00),
1429
- self.backend.bk_abs(x.S0),
1430
- self.backend.bk_abs(x.S1),
1431
- self.backend.bk_abs(x.S2),
1432
- self.backend.bk_abs(x.S2L),x.j1,x.j2,backend=self.backend)
1433
- def inv(self,x):
1434
- return scat(1/(x.P00),1/(x.S0),1/(x.S1),1/(x.S2),1/(x.S2L),x.j1,x.j2,backend=self.backend)
1949
+ return scat(
1950
+ self.backend.bk_reduce_sum(x.P00, axis=axis),
1951
+ self.backend.bk_reduce_sum(x.S0, axis=axis),
1952
+ self.backend.bk_reduce_sum(x.S1, axis=axis),
1953
+ self.backend.bk_reduce_sum(x.S2, axis=axis),
1954
+ self.backend.bk_reduce_sum(x.S2L, axis=axis),
1955
+ x.j1,
1956
+ x.j2,
1957
+ backend=self.backend,
1958
+ )
1959
+
1960
+ def ldiff(self, sig, x):
1961
+ return scat(
1962
+ x.domult(sig.P00, x.P00) * x.domult(sig.P00, x.P00),
1963
+ x.domult(sig.S0, x.S0) * x.domult(sig.S0, x.S0),
1964
+ x.domult(sig.S1, x.S1) * x.domult(sig.S1, x.S1),
1965
+ x.domult(sig.S2, x.S2) * x.domult(sig.S2, x.S2),
1966
+ x.domult(sig.S2L, x.S2L) * x.domult(sig.S2L, x.S2L),
1967
+ x.j1,
1968
+ x.j2,
1969
+ backend=self.backend,
1970
+ )
1971
+
1972
+ def log(self, x):
1973
+ return scat(
1974
+ self.backend.bk_log(x.P00),
1975
+ self.backend.bk_log(x.S0),
1976
+ self.backend.bk_log(x.S1),
1977
+ self.backend.bk_log(x.S2),
1978
+ self.backend.bk_log(x.S2L),
1979
+ x.j1,
1980
+ x.j2,
1981
+ backend=self.backend,
1982
+ )
1983
+
1984
+ def abs(self, x):
1985
+ return scat(
1986
+ self.backend.bk_abs(x.P00),
1987
+ self.backend.bk_abs(x.S0),
1988
+ self.backend.bk_abs(x.S1),
1989
+ self.backend.bk_abs(x.S2),
1990
+ self.backend.bk_abs(x.S2L),
1991
+ x.j1,
1992
+ x.j2,
1993
+ backend=self.backend,
1994
+ )
1995
+
1996
+ def inv(self, x):
1997
+ return scat(
1998
+ 1 / (x.P00),
1999
+ 1 / (x.S0),
2000
+ 1 / (x.S1),
2001
+ 1 / (x.S2),
2002
+ 1 / (x.S2L),
2003
+ x.j1,
2004
+ x.j2,
2005
+ backend=self.backend,
2006
+ )
1435
2007
 
1436
2008
  def one(self):
1437
- return scat(1.0,1.0,1.0,1.0,1.0,[0],[0],backend=self.backend)
2009
+ return scat(1.0, 1.0, 1.0, 1.0, 1.0, [0], [0], backend=self.backend)
1438
2010
 
1439
2011
  @tf_function
1440
- def eval_comp_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
2012
+ def eval_comp_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1441
2013
 
1442
- res=self.eval(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1443
- return res.P00,res.S0,res.S1,res.S2,res.S2L,res.j1,res.j2
2014
+ res = self.eval(image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off)
2015
+ return res.P00, res.S0, res.S1, res.S2, res.S2L, res.j1, res.j2
1444
2016
 
1445
- def eval_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1446
- p0,s0,s1,s2,s2l,j1,j2=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1447
- return scat(p0,s0,s1,s2,s2l,j1,j2,backend=self.backend)
1448
-
1449
-
1450
-
2017
+ def eval_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
2018
+ p0, s0, s1, s2, s2l, j1, j2 = self.eval_comp_fast(
2019
+ image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off
2020
+ )
2021
+ return scat(p0, s0, s1, s2, s2l, j1, j2, backend=self.backend)