foscat 3.0.9__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat.py CHANGED
@@ -1,412 +1,602 @@
1
- import foscat.FoCUS as FOC
2
- import numpy as np
3
- import tensorflow as tf
4
1
  import pickle
5
- import foscat.backend as bk
2
+ import sys
3
+
6
4
  import healpy as hp
7
-
5
+ import numpy as np
6
+
7
+ import foscat.backend as bk
8
+ import foscat.FoCUS as FOC
9
+
10
+ # Vérifier si TensorFlow est importé et défini
11
+ tf_defined = "tensorflow" in sys.modules
12
+
13
+ if tf_defined:
14
+ import tensorflow as tf
15
+
16
+ tf_function = (
17
+ tf.function
18
+ ) # Facultatif : si vous voulez utiliser TensorFlow dans ce script
19
+ else:
20
+
21
+ def tf_function(func):
22
+ return func
23
+
24
+
8
25
  def read(filename):
9
- thescat=scat(1,1,1,1,1,[0],[0])
26
+ thescat = scat(1, 1, 1, 1, 1, [0], [0])
10
27
  return thescat.read(filename)
11
-
28
+
29
+
12
30
  class scat:
13
- def __init__(self,p00,s0,s1,s2,s2l,j1,j2,cross=False,backend=None):
14
- self.bk_type='SCAT'
15
- self.P00=p00
16
- self.S0=s0
17
- self.S1=s1
18
- self.S2=s2
19
- self.S2L=s2l
20
- self.j1=j1
21
- self.j2=j2
22
- self.cross=cross
23
- self.backend=backend
24
-
25
- def set_bk_type(self,bk_type):
26
- self.bk_type=bk_type
27
-
31
+ def __init__(self, p00, s0, s1, s2, s2l, j1, j2, cross=False, backend=None):
32
+ self.bk_type = "SCAT"
33
+ self.P00 = p00
34
+ self.S0 = s0
35
+ self.S1 = s1
36
+ self.S2 = s2
37
+ self.S2L = s2l
38
+ self.j1 = j1
39
+ self.j2 = j2
40
+ self.cross = cross
41
+ self.backend = backend
42
+
43
+ def set_bk_type(self, bk_type):
44
+ self.bk_type = bk_type
45
+
28
46
  def get_j_idx(self):
29
- return self.j1,self.j2
30
-
47
+ return self.j1, self.j2
48
+
31
49
  def get_S0(self):
32
- return(self.S0)
50
+ return self.S0
33
51
 
34
52
  def get_S1(self):
35
- return(self.S1)
36
-
53
+ return self.S1
54
+
37
55
  def get_S2(self):
38
- return(self.S2)
56
+ return self.S2
39
57
 
40
58
  def get_S2L(self):
41
- return(self.S2L)
59
+ return self.S2L
42
60
 
43
61
  def get_P00(self):
44
- return(self.P00)
62
+ return self.P00
45
63
 
46
64
  def reset_P00(self):
47
- self.P00=0*self.P00
65
+ self.P00 = 0 * self.P00
48
66
 
49
67
  def constant(self):
50
- return scat(self.backend.constant(self.P00 ), \
51
- self.backend.constant(self.S0 ), \
52
- self.backend.constant(self.S1 ), \
53
- self.backend.constant(self.S2 ), \
54
- self.backend.constant(self.S2L ), \
55
- self.j1 , \
56
- self.j2 ,backend=self.backend)
57
-
58
- def domult(self,x,y):
59
- if x.dtype==y.dtype:
60
- return x*y
61
-
62
- if self.backend.bk_is_complex(x):
63
-
64
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
65
- else:
66
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
67
-
68
- def dodiv(self,x,y):
69
- if x.dtype==y.dtype:
70
- return x/y
71
- if self.backend.bk_is_complex(x):
72
-
73
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
74
- else:
75
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
76
-
77
- def domin(self,x,y):
78
- if x.dtype==y.dtype:
79
- return x-y
80
-
81
- if self.backend.bk_is_complex(x):
82
-
83
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
84
- else:
85
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
86
-
87
- def doadd(self,x,y):
88
- if x.dtype==y.dtype:
89
- return x+y
90
-
91
- if self.backend.bk_is_complex(x):
92
-
93
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
94
- else:
95
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
96
-
97
- def relu(self):
98
-
99
- return scat(self.backend.bk_relu(self.P00), \
100
- self.backend.bk_relu(self.S0), \
101
- self.backend.bk_relu(self.S1), \
102
- self.backend.bk_relu(self.S2), \
103
- self.backend.bk_relu(self.S2L), \
104
- self.j1,self.j2,backend=self.backend)
105
-
106
- def __add__(self,other):
107
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
108
- isinstance(other, bool) or isinstance(other, scat)
109
-
68
+ return scat(
69
+ self.backend.constant(self.P00),
70
+ self.backend.constant(self.S0),
71
+ self.backend.constant(self.S1),
72
+ self.backend.constant(self.S2),
73
+ self.backend.constant(self.S2L),
74
+ self.j1,
75
+ self.j2,
76
+ backend=self.backend,
77
+ )
78
+
79
+ def domult(self, x, y):
80
+ try:
81
+ return x * y
82
+ except:
83
+ if x.dtype == y.dtype:
84
+ return x * y
85
+ if self.backend.bk_is_complex(x):
86
+
87
+ return self.backend.bk_complex(
88
+ self.backend.bk_real(x) * y, self.backend.bk_imag(x) * y
89
+ )
90
+ else:
91
+ return self.backend.bk_complex(
92
+ self.backend.bk_real(y) * x, self.backend.bk_imag(y) * x
93
+ )
94
+
95
+ def dodiv(self, x, y):
96
+ try:
97
+ return x / y
98
+ except:
99
+ if x.dtype == y.dtype:
100
+ return x / y
101
+ if self.backend.bk_is_complex(x):
102
+
103
+ return self.backend.bk_complex(
104
+ self.backend.bk_real(x) / y, self.backend.bk_imag(x) / y
105
+ )
106
+ else:
107
+ return self.backend.bk_complex(
108
+ x / self.backend.bk_real(y), x / self.backend.bk_imag(y)
109
+ )
110
+
111
+ def domin(self, x, y):
112
+ try:
113
+ return x - y
114
+ except:
115
+ if x.dtype == y.dtype:
116
+ return x - y
117
+
118
+ if self.backend.bk_is_complex(x):
119
+
120
+ return self.backend.bk_complex(
121
+ self.backend.bk_real(x) - y, self.backend.bk_imag(x) - y
122
+ )
123
+ else:
124
+ return self.backend.bk_complex(
125
+ x - self.backend.bk_real(y), x - self.backend.bk_imag(y)
126
+ )
127
+
128
+ def doadd(self, x, y):
129
+ try:
130
+ return x + y
131
+ except:
132
+ if x.dtype == y.dtype:
133
+ return x + y
134
+ if self.backend.bk_is_complex(x):
135
+
136
+ return self.backend.bk_complex(
137
+ self.backend.bk_real(x) + y, self.backend.bk_imag(x) + y
138
+ )
139
+ else:
140
+ return self.backend.bk_complex(
141
+ x + self.backend.bk_real(y), x + self.backend.bk_imag(y)
142
+ )
143
+
144
+ def __add__(self, other):
145
+ assert (
146
+ isinstance(other, float)
147
+ or isinstance(other, np.float32)
148
+ or isinstance(other, int)
149
+ or isinstance(other, bool)
150
+ or isinstance(other, scat)
151
+ )
152
+
110
153
  if isinstance(other, scat):
111
- return scat(self.doadd(self.P00,other.P00), \
112
- self.doadd(self.S0, other.S0), \
113
- self.doadd(self.S1, other.S1), \
114
- self.doadd(self.S2, other.S2), \
115
- self.doadd(self.S2L, other.S2L), \
116
- self.j1,self.j2,backend=self.backend)
154
+ return scat(
155
+ self.doadd(self.P00, other.P00),
156
+ self.doadd(self.S0, other.S0),
157
+ self.doadd(self.S1, other.S1),
158
+ self.doadd(self.S2, other.S2),
159
+ self.doadd(self.S2L, other.S2L),
160
+ self.j1,
161
+ self.j2,
162
+ backend=self.backend,
163
+ )
117
164
  else:
118
- return scat((self.P00+ other), \
119
- (self.S0+ other), \
120
- (self.S1+ other), \
121
- (self.S2+ other), \
122
- (self.S2L+ other), \
123
- self.j1,self.j2,backend=self.backend)
124
-
125
- def toreal(self,value):
165
+ return scat(
166
+ (self.P00 + other),
167
+ (self.S0 + other),
168
+ (self.S1 + other),
169
+ (self.S2 + other),
170
+ (self.S2L + other),
171
+ self.j1,
172
+ self.j2,
173
+ backend=self.backend,
174
+ )
175
+
176
+ def toreal(self, value):
126
177
  if value is None:
127
178
  return None
128
-
179
+
129
180
  return self.backend.bk_real(value)
130
181
 
131
- def addcomplex(self,value,amp):
182
+ def addcomplex(self, value, amp):
132
183
  if value is None:
133
184
  return None
134
-
135
- return self.backend.bk_complex(value,amp*value)
136
-
137
- def add_complex(self,amp):
138
- return scat(self.addcomplex(self.P00,amp), \
139
- self.addcomplex(self.S0,amp), \
140
- self.addcomplex(self.S1,amp), \
141
- self.addcomplex(self.S2,amp), \
142
- self.addcomplex(self.S2L,amp), \
143
- self.j1,self.j2,backend=self.backend)
144
-
185
+
186
+ return self.backend.bk_complex(value, amp * value)
187
+
188
+ def add_complex(self, amp):
189
+ return scat(
190
+ self.addcomplex(self.P00, amp),
191
+ self.addcomplex(self.S0, amp),
192
+ self.addcomplex(self.S1, amp),
193
+ self.addcomplex(self.S2, amp),
194
+ self.addcomplex(self.S2L, amp),
195
+ self.j1,
196
+ self.j2,
197
+ backend=self.backend,
198
+ )
199
+
145
200
  def real(self):
146
- return scat(self.toreal(self.P00), \
147
- self.toreal(self.S0), \
148
- self.toreal(self.S1), \
149
- self.toreal(self.S2), \
150
- self.toreal(self.S2L), \
151
- self.j1,self.j2,backend=self.backend)
152
-
153
- def __radd__(self,other):
201
+ return scat(
202
+ self.toreal(self.P00),
203
+ self.toreal(self.S0),
204
+ self.toreal(self.S1),
205
+ self.toreal(self.S2),
206
+ self.toreal(self.S2L),
207
+ self.j1,
208
+ self.j2,
209
+ backend=self.backend,
210
+ )
211
+
212
+ def __radd__(self, other):
154
213
  return self.__add__(other)
155
214
 
156
- def __truediv__(self,other):
157
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
158
- isinstance(other, bool) or isinstance(other, scat)
159
-
215
+ def __truediv__(self, other):
216
+ assert (
217
+ isinstance(other, float)
218
+ or isinstance(other, np.float32)
219
+ or isinstance(other, int)
220
+ or isinstance(other, bool)
221
+ or isinstance(other, scat)
222
+ )
223
+
160
224
  if isinstance(other, scat):
161
- return scat(self.dodiv(self.P00, other.P00), \
162
- self.dodiv(self.S0, other.S0), \
163
- self.dodiv(self.S1, other.S1), \
164
- self.dodiv(self.S2, other.S2), \
165
- self.dodiv(self.S2L, other.S2L), \
166
- self.j1,self.j2,backend=self.backend)
225
+ return scat(
226
+ self.dodiv(self.P00, other.P00),
227
+ self.dodiv(self.S0, other.S0),
228
+ self.dodiv(self.S1, other.S1),
229
+ self.dodiv(self.S2, other.S2),
230
+ self.dodiv(self.S2L, other.S2L),
231
+ self.j1,
232
+ self.j2,
233
+ backend=self.backend,
234
+ )
167
235
  else:
168
- return scat((self.P00/ other), \
169
- (self.S0/ other), \
170
- (self.S1/ other), \
171
- (self.S2/ other), \
172
- (self.S2L/ other), \
173
- self.j1,self.j2,backend=self.backend)
174
-
175
-
176
- def __rtruediv__(self,other):
177
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
178
- isinstance(other, bool) or isinstance(other, scat)
179
-
236
+ return scat(
237
+ (self.P00 / other),
238
+ (self.S0 / other),
239
+ (self.S1 / other),
240
+ (self.S2 / other),
241
+ (self.S2L / other),
242
+ self.j1,
243
+ self.j2,
244
+ backend=self.backend,
245
+ )
246
+
247
+ def __rtruediv__(self, other):
248
+ assert (
249
+ isinstance(other, float)
250
+ or isinstance(other, np.float32)
251
+ or isinstance(other, int)
252
+ or isinstance(other, bool)
253
+ or isinstance(other, scat)
254
+ )
255
+
180
256
  if isinstance(other, scat):
181
- return scat(self.dodiv(other.P00, self.P00), \
182
- self.dodiv(other.S0 , self.S0), \
183
- self.dodiv(other.S1 , self.S1), \
184
- self.dodiv(other.S2 , self.S2), \
185
- self.dodiv(other.S2L , self.S2L), \
186
- self.j1,self.j2,backend=self.backend)
257
+ return scat(
258
+ self.dodiv(other.P00, self.P00),
259
+ self.dodiv(other.S0, self.S0),
260
+ self.dodiv(other.S1, self.S1),
261
+ self.dodiv(other.S2, self.S2),
262
+ self.dodiv(other.S2L, self.S2L),
263
+ self.j1,
264
+ self.j2,
265
+ backend=self.backend,
266
+ )
187
267
  else:
188
- return scat((other/ self.P00), \
189
- (other / self.S0), \
190
- (other / self.S1), \
191
- (other / self.S2), \
192
- (other / self.S2L), \
193
- self.j1,self.j2,backend=self.backend)
194
-
195
- def __sub__(self,other):
196
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
197
- isinstance(other, bool) or isinstance(other, scat)
198
-
268
+ return scat(
269
+ (other / self.P00),
270
+ (other / self.S0),
271
+ (other / self.S1),
272
+ (other / self.S2),
273
+ (other / self.S2L),
274
+ self.j1,
275
+ self.j2,
276
+ backend=self.backend,
277
+ )
278
+
279
+ def __sub__(self, other):
280
+ assert (
281
+ isinstance(other, float)
282
+ or isinstance(other, np.float32)
283
+ or isinstance(other, int)
284
+ or isinstance(other, bool)
285
+ or isinstance(other, scat)
286
+ )
287
+
199
288
  if isinstance(other, scat):
200
- return scat(self.domin(self.P00, other.P00), \
201
- self.domin(self.S0, other.S0), \
202
- self.domin(self.S1, other.S1), \
203
- self.domin(self.S2, other.S2), \
204
- self.domin(self.S2L, other.S2L), \
205
- self.j1,self.j2,backend=self.backend)
289
+ return scat(
290
+ self.domin(self.P00, other.P00),
291
+ self.domin(self.S0, other.S0),
292
+ self.domin(self.S1, other.S1),
293
+ self.domin(self.S2, other.S2),
294
+ self.domin(self.S2L, other.S2L),
295
+ self.j1,
296
+ self.j2,
297
+ backend=self.backend,
298
+ )
206
299
  else:
207
- return scat((self.P00- other), \
208
- (self.S0- other), \
209
- (self.S1- other), \
210
- (self.S2- other), \
211
- (self.S2L- other), \
212
- self.j1,self.j2,backend=self.backend)
213
-
214
- def __rsub__(self,other):
215
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
216
- isinstance(other, bool) or isinstance(other, scat)
217
-
300
+ return scat(
301
+ (self.P00 - other),
302
+ (self.S0 - other),
303
+ (self.S1 - other),
304
+ (self.S2 - other),
305
+ (self.S2L - other),
306
+ self.j1,
307
+ self.j2,
308
+ backend=self.backend,
309
+ )
310
+
311
+ def __rsub__(self, other):
312
+ assert (
313
+ isinstance(other, float)
314
+ or isinstance(other, np.float32)
315
+ or isinstance(other, int)
316
+ or isinstance(other, bool)
317
+ or isinstance(other, scat)
318
+ )
319
+
218
320
  if isinstance(other, scat):
219
- return scat(self.domin(other.P00,self.P00), \
220
- self.domin(other.S0, self.S0), \
221
- self.domin(other.S1, self.S1), \
222
- self.domin(other.S2, self.S2), \
223
- self.domin(other.S2L, self.S2L), \
224
- self.j1,self.j2,backend=self.backend)
321
+ return scat(
322
+ self.domin(other.P00, self.P00),
323
+ self.domin(other.S0, self.S0),
324
+ self.domin(other.S1, self.S1),
325
+ self.domin(other.S2, self.S2),
326
+ self.domin(other.S2L, self.S2L),
327
+ self.j1,
328
+ self.j2,
329
+ backend=self.backend,
330
+ )
225
331
  else:
226
- return scat((other-self.P00), \
227
- (other-self.S0), \
228
- (other-self.S1), \
229
- (other-self.S2), \
230
- (other-self.S2L), \
231
- self.j1,self.j2,backend=self.backend)
232
-
233
- def __mul__(self,other):
234
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
235
- isinstance(other, bool) or isinstance(other, scat)
236
-
332
+ return scat(
333
+ (other - self.P00),
334
+ (other - self.S0),
335
+ (other - self.S1),
336
+ (other - self.S2),
337
+ (other - self.S2L),
338
+ self.j1,
339
+ self.j2,
340
+ backend=self.backend,
341
+ )
342
+
343
+ def __mul__(self, other):
344
+ assert (
345
+ isinstance(other, float)
346
+ or isinstance(other, np.float32)
347
+ or isinstance(other, int)
348
+ or isinstance(other, bool)
349
+ or isinstance(other, scat)
350
+ )
351
+
237
352
  if isinstance(other, scat):
238
- return scat(self.domult(self.P00, other.P00), \
239
- self.domult(self.S0, other.S0), \
240
- self.domult(self.S1, other.S1), \
241
- self.domult(self.S2, other.S2), \
242
- self.domult(self.S2L, other.S2L), \
243
- self.j1,self.j2,backend=self.backend)
353
+ return scat(
354
+ self.domult(self.P00, other.P00),
355
+ self.domult(self.S0, other.S0),
356
+ self.domult(self.S1, other.S1),
357
+ self.domult(self.S2, other.S2),
358
+ self.domult(self.S2L, other.S2L),
359
+ self.j1,
360
+ self.j2,
361
+ backend=self.backend,
362
+ )
244
363
  else:
245
- return scat((self.P00* other), \
246
- (self.S0* other), \
247
- (self.S1* other), \
248
- (self.S2* other), \
249
- (self.S2L* other), \
250
- self.j1,self.j2,backend=self.backend)
364
+ return scat(
365
+ (self.P00 * other),
366
+ (self.S0 * other),
367
+ (self.S1 * other),
368
+ (self.S2 * other),
369
+ (self.S2L * other),
370
+ self.j1,
371
+ self.j2,
372
+ backend=self.backend,
373
+ )
374
+
251
375
  def relu(self):
252
- return scat(self.backend.bk_relu(self.P00),
253
- self.backend.bk_relu(self.S0),
254
- self.backend.bk_relu(self.S1),
255
- self.backend.bk_relu(self.S2),
256
- self.backend.bk_relu(self.S2L), \
257
- self.j1,self.j2,backend=self.backend)
258
-
259
-
260
- def __rmul__(self,other):
261
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
262
- isinstance(other, bool) or isinstance(other, scat)
263
-
376
+ return scat(
377
+ self.backend.bk_relu(self.P00),
378
+ self.backend.bk_relu(self.S0),
379
+ self.backend.bk_relu(self.S1),
380
+ self.backend.bk_relu(self.S2),
381
+ self.backend.bk_relu(self.S2L),
382
+ self.j1,
383
+ self.j2,
384
+ backend=self.backend,
385
+ )
386
+
387
+ def __rmul__(self, other):
388
+ assert (
389
+ isinstance(other, float)
390
+ or isinstance(other, np.float32)
391
+ or isinstance(other, int)
392
+ or isinstance(other, bool)
393
+ or isinstance(other, scat)
394
+ )
395
+
264
396
  if isinstance(other, scat):
265
- return scat(self.domult(self.P00, other.P00), \
266
- self.domult(self.S0, other.S0), \
267
- self.domult(self.S1, other.S1), \
268
- self.domult(self.S2, other.S2), \
269
- self.domult(self.S2L, other.S2L), \
270
- self.j1,self.j2,backend=self.backend)
397
+ return scat(
398
+ self.domult(self.P00, other.P00),
399
+ self.domult(self.S0, other.S0),
400
+ self.domult(self.S1, other.S1),
401
+ self.domult(self.S2, other.S2),
402
+ self.domult(self.S2L, other.S2L),
403
+ self.j1,
404
+ self.j2,
405
+ backend=self.backend,
406
+ )
271
407
  else:
272
- return scat((self.P00* other), \
273
- (self.S0* other), \
274
- (self.S1* other), \
275
- (self.S2* other), \
276
- (self.S2L* other), \
277
- self.j1,self.j2,backend=self.backend)
278
-
279
- def l1_abs(self,x):
280
- y=self.get_np(x)
408
+ return scat(
409
+ (self.P00 * other),
410
+ (self.S0 * other),
411
+ (self.S1 * other),
412
+ (self.S2 * other),
413
+ (self.S2L * other),
414
+ self.j1,
415
+ self.j2,
416
+ backend=self.backend,
417
+ )
418
+
419
+ def l1_abs(self, x):
420
+ y = self.get_np(x)
281
421
  if self.backend.bk_is_complex(y):
282
- tmp=y.real*y.real+y.imag*y.imag
283
- tmp=np.sign(tmp)*np.sqrt(np.fabs(tmp))
284
- y=tmp
285
-
286
- return(y)
287
-
288
- def plot(self,name=None,hold=True,color='blue',lw=1,legend=True):
422
+ tmp = y.real * y.real + y.imag * y.imag
423
+ tmp = np.sign(tmp) * np.sqrt(np.fabs(tmp))
424
+ y = tmp
425
+
426
+ return y
427
+
428
+ def plot(self, name=None, hold=True, color="blue", lw=1, legend=True):
289
429
 
290
430
  import matplotlib.pyplot as plt
291
431
 
292
- j1,j2=self.get_j_idx()
293
-
432
+ j1, j2 = self.get_j_idx()
433
+
294
434
  if name is None:
295
- name=''
435
+ name = ""
296
436
 
297
437
  if hold:
298
- plt.figure(figsize=(16,8))
299
-
300
- test=None
438
+ plt.figure(figsize=(16, 8))
439
+
440
+ test = None
301
441
  plt.subplot(2, 2, 1)
302
- tmp=abs(self.get_np(self.S1))
303
- if len(tmp.shape)==4:
442
+ tmp = abs(self.get_np(self.S1))
443
+ if len(tmp.shape) == 4:
304
444
  for k in range(tmp.shape[3]):
305
445
  for i1 in range(tmp.shape[0]):
306
446
  for i2 in range(tmp.shape[1]):
307
447
  if test is None:
308
- test=1
309
- plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
448
+ test = 1
449
+ plt.plot(
450
+ tmp[i1, i2, :, k],
451
+ color=color,
452
+ label=r"%s $S_{1}$" % (name),
453
+ lw=lw,
454
+ )
310
455
  else:
311
- plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
456
+ plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
312
457
  else:
313
458
  for k in range(tmp.shape[2]):
314
459
  for i1 in range(tmp.shape[0]):
315
460
  if test is None:
316
- test=1
317
- plt.plot(tmp[i1,:,k],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
461
+ test = 1
462
+ plt.plot(
463
+ tmp[i1, :, k],
464
+ color=color,
465
+ label=r"%s $S_{1}$" % (name),
466
+ lw=lw,
467
+ )
318
468
  else:
319
- plt.plot(tmp[i1,:,k],color=color, lw=lw)
320
- plt.yscale('log')
321
- plt.ylabel('S1')
322
- plt.xlabel(r'$j_{1}$')
469
+ plt.plot(tmp[i1, :, k], color=color, lw=lw)
470
+ plt.yscale("log")
471
+ plt.ylabel("S1")
472
+ plt.xlabel(r"$j_{1}$")
323
473
  plt.legend()
324
474
 
325
- test=None
475
+ test = None
326
476
  plt.subplot(2, 2, 2)
327
- tmp=abs(self.get_np(self.P00))
328
- if len(tmp.shape)==4:
477
+ tmp = abs(self.get_np(self.P00))
478
+ if len(tmp.shape) == 4:
329
479
  for k in range(tmp.shape[3]):
330
480
  for i1 in range(tmp.shape[0]):
331
- for i2 in range(tmp.shape[0]):
481
+ for i2 in range(tmp.shape[1]):
332
482
  if test is None:
333
- test=1
334
- plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
483
+ test = 1
484
+ plt.plot(
485
+ tmp[i1, i2, :, k],
486
+ color=color,
487
+ label=r"%s $P_{00}$" % (name),
488
+ lw=lw,
489
+ )
335
490
  else:
336
- plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
491
+ plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
337
492
  else:
338
493
  for k in range(tmp.shape[2]):
339
494
  for i1 in range(tmp.shape[0]):
340
495
  if test is None:
341
- test=1
342
- plt.plot(tmp[i1,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
496
+ test = 1
497
+ plt.plot(
498
+ tmp[i1, :, k],
499
+ color=color,
500
+ label=r"%s $P_{00}$" % (name),
501
+ lw=lw,
502
+ )
343
503
  else:
344
- plt.plot(tmp[i1,:,k],color=color, lw=lw)
345
- plt.yscale('log')
346
- plt.ylabel('P00')
347
- plt.xlabel(r'$j_{1}$')
504
+ plt.plot(tmp[i1, :, k], color=color, lw=lw)
505
+ plt.yscale("log")
506
+ plt.ylabel("P00")
507
+ plt.xlabel(r"$j_{1}$")
348
508
  plt.legend()
349
-
350
- ax1=plt.subplot(2, 2, 3)
509
+
510
+ ax1 = plt.subplot(2, 2, 3)
351
511
  ax2 = ax1.twiny()
352
- n=0
353
- tmp=abs(self.get_np(self.S2))
354
- lname=r'%s $S_{2}$' % (name)
355
- ax1.set_ylabel(r'$S_{2}$ [L1 norm]')
356
- test=None
357
- tabx=[]
358
- tabnx=[]
359
- tab2x=[]
360
- tab2nx=[]
361
- if len(tmp.shape)==5:
512
+ n = 0
513
+ tmp = abs(self.get_np(self.S2))
514
+ lname = r"%s $S_{2}$" % (name)
515
+ ax1.set_ylabel(r"$S_{2}$ [L1 norm]")
516
+ test = None
517
+ tabx = []
518
+ tabnx = []
519
+ tab2x = []
520
+ tab2nx = []
521
+ if len(tmp.shape) == 5:
362
522
  for i0 in range(tmp.shape[0]):
363
523
  for i1 in range(tmp.shape[1]):
364
- for i2 in range(j1.max()+1):
524
+ for i2 in range(j1.max() + 1):
365
525
  for i3 in range(tmp.shape[3]):
366
526
  for i4 in range(tmp.shape[4]):
367
- if j2[j1==i2].shape[0]==1:
368
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4],'.', \
369
- color=color, lw=lw)
527
+ if j2[j1 == i2].shape[0] == 1:
528
+ ax1.plot(
529
+ j2[j1 == i2] + n,
530
+ tmp[i0, i1, j1 == i2, i3, i4],
531
+ ".",
532
+ color=color,
533
+ lw=lw,
534
+ )
370
535
  else:
371
536
  if legend and test is None:
372
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
373
- color=color, label=lname, lw=lw)
374
- test=1
375
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
376
- color=color, lw=lw)
377
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
378
- tabx=tabx+[k+n for k in j2[j1==i2]]
379
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
380
- tab2nx=tab2nx+['%d'%(i2)]
381
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
382
- n=n+j2[j1==i2].shape[0]-1
537
+ ax1.plot(
538
+ j2[j1 == i2] + n,
539
+ tmp[i0, i1, j1 == i2, i3, i4],
540
+ color=color,
541
+ label=lname,
542
+ lw=lw,
543
+ )
544
+ test = 1
545
+ ax1.plot(
546
+ j2[j1 == i2] + n,
547
+ tmp[i0, i1, j1 == i2, i3, i4],
548
+ color=color,
549
+ lw=lw,
550
+ )
551
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
552
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
553
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
554
+ tab2nx = tab2nx + ["%d" % (i2)]
555
+ ax1.axvline(
556
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
557
+ )
558
+ n = n + j2[j1 == i2].shape[0] - 1
383
559
  else:
384
560
  for i0 in range(tmp.shape[0]):
385
- for i2 in range(j1.max()+1):
561
+ for i2 in range(j1.max() + 1):
386
562
  for i3 in range(tmp.shape[2]):
387
563
  for i4 in range(tmp.shape[3]):
388
- if j2[j1==i2].shape[0]==1:
389
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4],'.', \
390
- color=color, lw=lw)
564
+ if j2[j1 == i2].shape[0] == 1:
565
+ ax1.plot(
566
+ j2[j1 == i2] + n,
567
+ tmp[i0, j1 == i2, i3, i4],
568
+ ".",
569
+ color=color,
570
+ lw=lw,
571
+ )
391
572
  else:
392
573
  if legend and test is None:
393
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
394
- color=color, label=lname, lw=lw)
395
- test=1
396
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
397
- color=color, lw=lw)
398
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
399
- tabx=tabx+[k+n for k in j2[j1==i2]]
400
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
401
- tab2nx=tab2nx+['%d'%(i2)]
402
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
403
- n=n+j2[j1==i2].shape[0]-1
404
- plt.yscale('log')
405
- ax1.set_xlim(0,n+2)
574
+ ax1.plot(
575
+ j2[j1 == i2] + n,
576
+ tmp[i0, j1 == i2, i3, i4],
577
+ color=color,
578
+ label=lname,
579
+ lw=lw,
580
+ )
581
+ test = 1
582
+ ax1.plot(
583
+ j2[j1 == i2] + n,
584
+ tmp[i0, j1 == i2, i3, i4],
585
+ color=color,
586
+ lw=lw,
587
+ )
588
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
589
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
590
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
591
+ tab2nx = tab2nx + ["%d" % (i2)]
592
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
593
+ n = n + j2[j1 == i2].shape[0] - 1
594
+ plt.yscale("log")
595
+ ax1.set_xlim(0, n + 2)
406
596
  ax1.set_xticks(tabx)
407
- ax1.set_xticklabels(tabnx,fontsize=6)
408
- ax1.set_xlabel(r"$j_{2}$ ",fontsize=6)
409
-
597
+ ax1.set_xticklabels(tabnx, fontsize=6)
598
+ ax1.set_xlabel(r"$j_{2}$ ", fontsize=6)
599
+
410
600
  # Move twinned axis ticks and label from top to bottom
411
601
  ax2.xaxis.set_ticks_position("bottom")
412
602
  ax2.xaxis.set_label_position("bottom")
@@ -414,7 +604,7 @@ class scat:
414
604
  # Offset the twin axis below the host
415
605
  ax2.spines["bottom"].set_position(("axes", -0.15))
416
606
 
417
- # Turn on the frame for the twin axis, but then hide all
607
+ # Turn on the frame for the twin axis, but then hide all
418
608
  # but the bottom spine
419
609
  ax2.set_frame_on(True)
420
610
  ax2.patch.set_visible(False)
@@ -422,72 +612,102 @@ class scat:
422
612
  for sp in ax2.spines.values():
423
613
  sp.set_visible(False)
424
614
  ax2.spines["bottom"].set_visible(True)
425
- ax2.set_xlim(0,n+2)
615
+ ax2.set_xlim(0, n + 2)
426
616
  ax2.set_xticks(tab2x)
427
- ax2.set_xticklabels(tab2nx,fontsize=6)
428
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
617
+ ax2.set_xticklabels(tab2nx, fontsize=6)
618
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
429
619
  ax1.legend(frameon=0)
430
-
431
- ax1=plt.subplot(2, 2, 4)
620
+
621
+ ax1 = plt.subplot(2, 2, 4)
432
622
  ax2 = ax1.twiny()
433
- n=0
434
- tmp=abs(self.get_np(self.S2L))
435
- lname=r'%s $S2_{2}$' % (name)
436
- ax1.set_ylabel(r'$S_{2}$ [L2 norm]')
437
- test=None
438
- tabx=[]
439
- tabnx=[]
440
- tab2x=[]
441
- tab2nx=[]
442
- if len(tmp.shape)==5:
623
+ n = 0
624
+ tmp = abs(self.get_np(self.S2L))
625
+ lname = r"%s $S2_{2}$" % (name)
626
+ ax1.set_ylabel(r"$S_{2}$ [L2 norm]")
627
+ test = None
628
+ tabx = []
629
+ tabnx = []
630
+ tab2x = []
631
+ tab2nx = []
632
+ if len(tmp.shape) == 5:
443
633
  for i0 in range(tmp.shape[0]):
444
634
  for i1 in range(tmp.shape[1]):
445
- for i2 in range(j1.max()+1):
635
+ for i2 in range(j1.max() + 1):
446
636
  for i3 in range(tmp.shape[3]):
447
637
  for i4 in range(tmp.shape[4]):
448
- if j2[j1==i2].shape[0]==1:
449
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4],'.', \
450
- color=color, lw=lw)
638
+ if j2[j1 == i2].shape[0] == 1:
639
+ ax1.plot(
640
+ j2[j1 == i2] + n,
641
+ tmp[i0, i1, j1 == i2, i3, i4],
642
+ ".",
643
+ color=color,
644
+ lw=lw,
645
+ )
451
646
  else:
452
647
  if legend and test is None:
453
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
454
- color=color, label=lname, lw=lw)
455
- test=1
456
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
457
- color=color, lw=lw)
458
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
459
- tabx=tabx+[k+n for k in j2[j1==i2]]
460
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
461
- tab2nx=tab2nx+['%d'%(i2)]
462
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
463
- n=n+j2[j1==i2].shape[0]-1
648
+ ax1.plot(
649
+ j2[j1 == i2] + n,
650
+ tmp[i0, i1, j1 == i2, i3, i4],
651
+ color=color,
652
+ label=lname,
653
+ lw=lw,
654
+ )
655
+ test = 1
656
+ ax1.plot(
657
+ j2[j1 == i2] + n,
658
+ tmp[i0, i1, j1 == i2, i3, i4],
659
+ color=color,
660
+ lw=lw,
661
+ )
662
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
663
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
664
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
665
+ tab2nx = tab2nx + ["%d" % (i2)]
666
+ ax1.axvline(
667
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
668
+ )
669
+ n = n + j2[j1 == i2].shape[0] - 1
464
670
  else:
465
671
  for i0 in range(tmp.shape[0]):
466
- for i2 in range(j1.max()+1):
672
+ for i2 in range(j1.max() + 1):
467
673
  for i3 in range(tmp.shape[2]):
468
674
  for i4 in range(tmp.shape[3]):
469
- if j2[j1==i2].shape[0]==1:
470
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4],'.', \
471
- color=color, lw=lw)
675
+ if j2[j1 == i2].shape[0] == 1:
676
+ ax1.plot(
677
+ j2[j1 == i2] + n,
678
+ tmp[i0, j1 == i2, i3, i4],
679
+ ".",
680
+ color=color,
681
+ lw=lw,
682
+ )
472
683
  else:
473
684
  if legend and test is None:
474
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
475
- color=color, label=lname, lw=lw)
476
- test=1
477
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
478
- color=color, lw=lw)
479
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
480
- tabx=tabx+[k+n for k in j2[j1==i2]]
481
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
482
- tab2nx=tab2nx+['%d'%(i2)]
483
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
484
- n=n+j2[j1==i2].shape[0]-1
485
- plt.yscale('log')
486
- ax1.set_xlim(-1,n+3)
685
+ ax1.plot(
686
+ j2[j1 == i2] + n,
687
+ tmp[i0, j1 == i2, i3, i4],
688
+ color=color,
689
+ label=lname,
690
+ lw=lw,
691
+ )
692
+ test = 1
693
+ ax1.plot(
694
+ j2[j1 == i2] + n,
695
+ tmp[i0, j1 == i2, i3, i4],
696
+ color=color,
697
+ lw=lw,
698
+ )
699
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
700
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
701
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
702
+ tab2nx = tab2nx + ["%d" % (i2)]
703
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
704
+ n = n + j2[j1 == i2].shape[0] - 1
705
+ plt.yscale("log")
706
+ ax1.set_xlim(-1, n + 3)
487
707
  ax1.set_xticks(tabx)
488
- ax1.set_xticklabels(tabnx,fontsize=6)
489
- ax1.set_xlabel(r"$j_{2}$",fontsize=6)
490
-
708
+ ax1.set_xticklabels(tabnx, fontsize=6)
709
+ ax1.set_xlabel(r"$j_{2}$", fontsize=6)
710
+
491
711
  # Move twinned axis ticks and label from top to bottom
492
712
  ax2.xaxis.set_ticks_position("bottom")
493
713
  ax2.xaxis.set_label_position("bottom")
@@ -495,7 +715,7 @@ class scat:
495
715
  # Offset the twin axis below the host
496
716
  ax2.spines["bottom"].set_position(("axes", -0.15))
497
717
 
498
- # Turn on the frame for the twin axis, but then hide all
718
+ # Turn on the frame for the twin axis, but then hide all
499
719
  # but the bottom spine
500
720
  ax2.set_frame_on(True)
501
721
  ax2.patch.set_visible(False)
@@ -503,252 +723,351 @@ class scat:
503
723
  for sp in ax2.spines.values():
504
724
  sp.set_visible(False)
505
725
  ax2.spines["bottom"].set_visible(True)
506
- ax2.set_xlim(0,n+3)
726
+ ax2.set_xlim(0, n + 3)
507
727
  ax2.set_xticks(tab2x)
508
- ax2.set_xticklabels(tab2nx,fontsize=6)
509
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
728
+ ax2.set_xticklabels(tab2nx, fontsize=6)
729
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
510
730
  ax1.legend(frameon=0)
511
-
512
- def save(self,filename):
513
- outlist=[self.get_S0().numpy(), \
514
- self.get_S1().numpy(), \
515
- self.get_S2().numpy(), \
516
- self.get_S2L().numpy(), \
517
- self.get_P00().numpy(), \
518
- self.j1, \
519
- self.j2]
520
-
521
- myout=open("%s.pkl"%(filename),"wb")
522
- pickle.dump(outlist,myout)
731
+
732
+ def save(self, filename):
733
+ outlist = [
734
+ self.get_S0().numpy(),
735
+ self.get_S1().numpy(),
736
+ self.get_S2().numpy(),
737
+ self.get_S2L().numpy(),
738
+ self.get_P00().numpy(),
739
+ self.j1,
740
+ self.j2,
741
+ ]
742
+
743
+ myout = open("%s.pkl" % (filename), "wb")
744
+ pickle.dump(outlist, myout)
523
745
  myout.close()
524
746
 
525
-
526
- def read(self,filename):
527
-
528
- outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
529
- return scat(outlist[4],outlist[0],outlist[1],outlist[2],outlist[3],outlist[5],outlist[6],backend=bk.foscat_backend('numpy'))
530
-
531
- def get_np(self,x):
747
+ def read(self, filename):
748
+
749
+ outlist = pickle.load(open("%s.pkl" % (filename), "rb"))
750
+ return scat(
751
+ outlist[4],
752
+ outlist[0],
753
+ outlist[1],
754
+ outlist[2],
755
+ outlist[3],
756
+ outlist[5],
757
+ outlist[6],
758
+ backend=bk.foscat_backend("numpy"),
759
+ )
760
+
761
+ def get_np(self, x):
532
762
  if isinstance(x, np.ndarray):
533
763
  return x
534
764
  else:
535
765
  return x.numpy()
536
766
 
537
767
  def std(self):
538
- return np.sqrt(((abs(self.get_np(self.S0)).std())**2+ \
539
- (abs(self.get_np(self.S1)).std())**2+ \
540
- (abs(self.get_np(self.S2)).std())**2+ \
541
- (abs(self.get_np(self.S2L)).std())**2+ \
542
- (abs(self.get_np(self.P00)).std())**2)/4)
768
+ return np.sqrt(
769
+ (
770
+ (abs(self.get_np(self.S0)).std()) ** 2
771
+ + (abs(self.get_np(self.S1)).std()) ** 2
772
+ + (abs(self.get_np(self.S2)).std()) ** 2
773
+ + (abs(self.get_np(self.S2L)).std()) ** 2
774
+ + (abs(self.get_np(self.P00)).std()) ** 2
775
+ )
776
+ / 4
777
+ )
543
778
 
544
779
  def mean(self):
545
- return abs(self.get_np(self.S0).mean()+ \
546
- self.get_np(self.S1).mean()+ \
547
- self.get_np(self.S2).mean()+ \
548
- self.get_np(self.S2L).mean()+ \
549
- self.get_np(self.P00).mean())/3
780
+ return (
781
+ abs(
782
+ self.get_np(self.S0).mean()
783
+ + self.get_np(self.S1).mean()
784
+ + self.get_np(self.S2).mean()
785
+ + self.get_np(self.S2L).mean()
786
+ + self.get_np(self.P00).mean()
787
+ )
788
+ / 3
789
+ )
550
790
 
551
791
  def sqrt(self):
552
792
 
793
+ s0 = self.backend.bk_sqrt(self.S0)
794
+ s1 = self.backend.bk_sqrt(self.S1)
795
+ p00 = self.backend.bk_sqrt(self.P00)
796
+ s2 = self.backend.bk_sqrt(self.S2)
797
+ s2L = self.backend.bk_sqrt(self.S2L)
553
798
 
554
- s0 =self.backend.bk_sqrt(self.S0)
555
- s1 =self.backend.bk_sqrt(self.S1)
556
- p00=self.backend.bk_sqrt(self.P00)
557
- s2 =self.backend.bk_sqrt(self.S2)
558
- s2L=self.backend.bk_sqrt(self.S2L)
559
-
560
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
561
-
799
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
562
800
 
563
801
  def L1(self):
564
802
 
803
+ s0 = self.backend.bk_L1(self.S0)
804
+ s1 = self.backend.bk_L1(self.S1)
805
+ p00 = self.backend.bk_L1(self.P00)
806
+ s2 = self.backend.bk_L1(self.S2)
807
+ s2L = self.backend.bk_L1(self.S2L)
808
+
809
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
565
810
 
566
- s0 =self.backend.bk_L1(self.S0)
567
- s1 =self.backend.bk_L1(self.S1)
568
- p00=self.backend.bk_L1(self.P00)
569
- s2 =self.backend.bk_L1(self.S2)
570
- s2L=self.backend.bk_L1(self.S2L)
571
-
572
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
573
-
574
811
  def square_comp(self):
575
812
 
813
+ s0 = self.backend.bk_square_comp(self.S0)
814
+ s1 = self.backend.bk_square_comp(self.S1)
815
+ p00 = self.backend.bk_square_comp(self.P00)
816
+ s2 = self.backend.bk_square_comp(self.S2)
817
+ s2L = self.backend.bk_square_comp(self.S2L)
818
+
819
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
576
820
 
577
- s0 =self.backend.bk_square_comp(self.S0)
578
- s1 =self.backend.bk_square_comp(self.S1)
579
- p00=self.backend.bk_square_comp(self.P00)
580
- s2 =self.backend.bk_square_comp(self.S2)
581
- s2L=self.backend.bk_square_comp(self.S2L)
582
-
583
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
584
-
585
- def iso_mean(self,repeat=False):
586
- shape=list(self.S2.shape)
587
- norient=self.S1.shape[2]
821
+ def iso_mean(self, repeat=False):
822
+ shape = list(self.S2.shape)
823
+ norient = self.S1.shape[2]
588
824
 
589
- S1 = self.backend.bk_reduce_mean(self.S1,2)
825
+ S1 = self.backend.bk_reduce_mean(self.S1, 2)
590
826
  if repeat:
591
- S1=self.backend.bk_reshape(self.backend.bk_repeat(S1,shape[2],1),self.S1.shape)
827
+ S1 = self.backend.bk_reshape(
828
+ self.backend.bk_repeat(S1, shape[2], 1), self.S1.shape
829
+ )
592
830
  else:
593
- S1=self.backend.bk_expand_dims(S1,-1)
594
-
831
+ S1 = self.backend.bk_expand_dims(S1, -1)
595
832
 
596
- P00 = self.backend.bk_reduce_mean(self.P00,2)
833
+ P00 = self.backend.bk_reduce_mean(self.P00, 2)
597
834
  if repeat:
598
- P00=self.backend.bk_reshape(self.backend.bk_repeat(P00,shape[2],1),self.S1.shape)
835
+ P00 = self.backend.bk_reshape(
836
+ self.backend.bk_repeat(P00, shape[2], 1), self.S1.shape
837
+ )
599
838
  else:
600
- P00=self.backend.bk_expand_dims(P00,-1)
839
+ P00 = self.backend.bk_expand_dims(P00, -1)
601
840
 
602
841
  if norient not in self.backend._iso_orient:
603
842
  self.backend.calc_iso_orient(norient)
604
-
843
+
605
844
  if self.backend.bk_is_complex(self.S2):
606
- lmat = self.backend._iso_orient_C[norient]
845
+ lmat = self.backend._iso_orient_C[norient]
607
846
  lmat_T = self.backend._iso_orient_C_T[norient]
608
847
  else:
609
- lmat = self.backend._iso_orient[norient]
848
+ lmat = self.backend._iso_orient[norient]
610
849
  lmat_T = self.backend._iso_orient_T[norient]
611
-
612
- S2=self.backend.bk_reshape(
613
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2,[shape[0],shape[1],norient*norient]),lmat),
614
- [shape[0],shape[1],norient])
615
- S2L=self.backend.bk_reshape(
616
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2L,[shape[0],shape[1],norient*norient]),lmat),
617
- [shape[0],shape[1],norient])
618
-
850
+
851
+ S2 = self.backend.bk_reshape(
852
+ self.backend.backend.matmul(
853
+ self.backend.bk_reshape(
854
+ self.S2, [shape[0], shape[1], norient * norient]
855
+ ),
856
+ lmat,
857
+ ),
858
+ [shape[0], shape[1], norient],
859
+ )
860
+ S2L = self.backend.bk_reshape(
861
+ self.backend.backend.matmul(
862
+ self.backend.bk_reshape(
863
+ self.S2L, [shape[0], shape[1], norient * norient]
864
+ ),
865
+ lmat,
866
+ ),
867
+ [shape[0], shape[1], norient],
868
+ )
869
+
619
870
  if repeat:
620
-
621
- S2=self.backend.bk_reshape(
622
- self.backend.backend.matmul(self.backend.bk_reshape(S2,[shape[0]*shape[1],norient]),lmat_T),
623
- self.S2.shape)
624
- S2L=self.backend.bk_reshape(
625
- self.backend.backend.matmul(self.backend.bk_reshape(S2L,[shape[0]*shape[1],norient]),lmat_T),
626
- self.S2.shape)
871
+
872
+ S2 = self.backend.bk_reshape(
873
+ self.backend.backend.matmul(
874
+ self.backend.bk_reshape(S2, [shape[0] * shape[1], norient]), lmat_T
875
+ ),
876
+ self.S2.shape,
877
+ )
878
+ S2L = self.backend.bk_reshape(
879
+ self.backend.backend.matmul(
880
+ self.backend.bk_reshape(S2L, [shape[0] * shape[1], norient]), lmat_T
881
+ ),
882
+ self.S2.shape,
883
+ )
627
884
  else:
628
- S2=self.backend.bk_expand_dims(S2,-1)
629
- S2L=self.backend.bk_expand_dims(S2L,-1)
885
+ S2 = self.backend.bk_expand_dims(S2, -1)
886
+ S2L = self.backend.bk_expand_dims(S2L, -1)
630
887
 
631
- return scat(P00,self.S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
888
+ return scat(P00, self.S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
632
889
 
633
-
634
- def fft_ang(self,nharm=1,imaginary=False):
635
- shape=list(self.S2.shape)
636
- norient=self.S1.shape[2]
890
+ def fft_ang(self, nharm=1, imaginary=False):
891
+ shape = list(self.S2.shape)
892
+ norient = self.S1.shape[2]
637
893
 
638
- nout=1+nharm
894
+ nout = 1 + nharm
639
895
  if imaginary:
640
- nout=1+nharm*2
641
-
642
- if (norient,nharm) not in self.backend._fft_1_orient:
643
- self.backend.calc_fft_orient(norient,nharm,imaginary)
644
-
896
+ nout = 1 + nharm * 2
897
+
898
+ if (norient, nharm) not in self.backend._fft_1_orient:
899
+ self.backend.calc_fft_orient(norient, nharm, imaginary)
900
+
645
901
  if self.backend.bk_is_complex(self.S1):
646
- lmat = self.backend._fft_1_orient_C[(norient,nharm,imaginary)]
902
+ lmat = self.backend._fft_1_orient_C[(norient, nharm, imaginary)]
647
903
  else:
648
- lmat = self.backend._fft_1_orient[(norient,nharm,imaginary)]
649
-
650
- S1=self.backend.bk_reshape(
651
- self.backend.backend.matmul(self.backend.bk_reshape(self.S1,[self.S1.shape[0],self.S1.shape[1],norient]),lmat),
652
- [self.S1.shape[0],self.S1.shape[1],nout])
653
-
654
- P00=self.backend.bk_reshape(
655
- self.backend.backend.matmul(self.backend.bk_reshape(self.P00,[self.S1.shape[0],self.S1.shape[1],norient]),lmat),
656
- [self.S1.shape[0],self.S1.shape[1],nout])
657
-
658
-
904
+ lmat = self.backend._fft_1_orient[(norient, nharm, imaginary)]
905
+
906
+ S1 = self.backend.bk_reshape(
907
+ self.backend.backend.matmul(
908
+ self.backend.bk_reshape(
909
+ self.S1, [self.S1.shape[0], self.S1.shape[1], norient]
910
+ ),
911
+ lmat,
912
+ ),
913
+ [self.S1.shape[0], self.S1.shape[1], nout],
914
+ )
915
+
916
+ P00 = self.backend.bk_reshape(
917
+ self.backend.backend.matmul(
918
+ self.backend.bk_reshape(
919
+ self.P00, [self.S1.shape[0], self.S1.shape[1], norient]
920
+ ),
921
+ lmat,
922
+ ),
923
+ [self.S1.shape[0], self.S1.shape[1], nout],
924
+ )
925
+
659
926
  if self.backend.bk_is_complex(self.S2):
660
- lmat = self.backend._fft_2_orient_C[(norient,nharm,imaginary)]
927
+ lmat = self.backend._fft_2_orient_C[(norient, nharm, imaginary)]
661
928
  else:
662
- lmat = self.backend._fft_2_orient[(norient,nharm,imaginary)]
663
-
664
- S2=self.backend.bk_reshape(
665
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2,[shape[0],shape[1],norient*norient]),lmat),
666
- [shape[0],shape[1],nout,nout])
667
- S2L=self.backend.bk_reshape(
668
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2L,[shape[0],shape[1],norient*norient]),lmat),
669
- [shape[0],shape[1],nout,nout])
670
-
671
- return scat(P00,self.S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
672
-
673
-
674
- def iso_std(self,repeat=False):
675
-
676
- val=(self-self.iso_mean(repeat=True)).square_comp()
929
+ lmat = self.backend._fft_2_orient[(norient, nharm, imaginary)]
930
+
931
+ S2 = self.backend.bk_reshape(
932
+ self.backend.backend.matmul(
933
+ self.backend.bk_reshape(
934
+ self.S2, [shape[0], shape[1], norient * norient]
935
+ ),
936
+ lmat,
937
+ ),
938
+ [shape[0], shape[1], nout, nout],
939
+ )
940
+ S2L = self.backend.bk_reshape(
941
+ self.backend.backend.matmul(
942
+ self.backend.bk_reshape(
943
+ self.S2L, [shape[0], shape[1], norient * norient]
944
+ ),
945
+ lmat,
946
+ ),
947
+ [shape[0], shape[1], nout, nout],
948
+ )
949
+
950
+ return scat(P00, self.S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
951
+
952
+ def iso_std(self, repeat=False):
953
+
954
+ val = (self - self.iso_mean(repeat=True)).square_comp()
677
955
  return (val.iso_mean(repeat=repeat)).L1()
678
956
 
679
957
  # ---------------------------------------------−---------
680
- def cleanval(self,x):
681
- x=x.numpy()
682
- x[np.isfinite(x)==False]=np.median(x[np.isfinite(x)])
958
+ def cleanval(self, x):
959
+ x = x.numpy()
960
+ x[~np.isfinite(x)] = np.median(x[np.isfinite(x)])
683
961
  return x
684
962
 
685
963
  def filter_inf(self):
686
- S1 = self.cleanval(self.S1)
687
- S0 = self.cleanval(self.S0)
964
+ S1 = self.cleanval(self.S1)
965
+ S0 = self.cleanval(self.S0)
688
966
  P00 = self.cleanval(self.P00)
689
- S2 = self.cleanval(self.S2)
967
+ S2 = self.cleanval(self.S2)
690
968
  S2L = self.cleanval(self.S2L)
691
969
 
692
- return scat(P00,S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
970
+ return scat(P00, S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
693
971
 
694
972
  # ---------------------------------------------−---------
695
- def interp(self,nscale,extend=False,constant=False,threshold=1E30,use_mask=False):
696
-
697
- if nscale+2>self.S1.shape[1]:
698
- print('Can not *interp* %d with a statistic described over %d'%(nscale,self.S1.shape[1]))
699
- return scat(self.P00,self.S0,self.S1,self.S2,self.S2L,self.j1,self.j2,backend=self.backend)
700
- if isinstance(self.S1,np.ndarray):
701
- s1=self.S1
702
- p0=self.P00
703
- s2=self.S2
704
- s2l=self.S2L
973
+ def interp(
974
+ self, nscale, extend=False, constant=False, threshold=1e30, use_mask=False
975
+ ):
976
+
977
+ if nscale + 2 > self.S1.shape[1]:
978
+ print(
979
+ "Can not *interp* %d with a statistic described over %d"
980
+ % (nscale, self.S1.shape[1])
981
+ )
982
+ return scat(
983
+ self.P00,
984
+ self.S0,
985
+ self.S1,
986
+ self.S2,
987
+ self.S2L,
988
+ self.j1,
989
+ self.j2,
990
+ backend=self.backend,
991
+ )
992
+ if isinstance(self.S1, np.ndarray):
993
+ s1 = self.S1
994
+ p0 = self.P00
995
+ s2 = self.S2
996
+ s2l = self.S2L
705
997
  else:
706
- s1=self.S1.numpy()
707
- p0=self.P00.numpy()
708
- s2=self.S2.numpy()
709
- s2l=self.S2L.numpy()
710
-
711
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
712
-
713
- if isinstance(threshold,scat):
714
- if isinstance(threshold.S1,np.ndarray):
715
- s1th=threshold.S1
716
- p0th=threshold.P00
717
- s2th=threshold.S2
718
- s2lth=threshold.S2L
998
+ s1 = self.S1.numpy()
999
+ p0 = self.P00.numpy()
1000
+ s2 = self.S2.numpy()
1001
+ s2l = self.S2L.numpy()
1002
+
1003
+ print(s1.sum(), p0.sum(), s2.sum(), s2l.sum())
1004
+
1005
+ if isinstance(threshold, scat):
1006
+ if isinstance(threshold.S1, np.ndarray):
1007
+ s1th = threshold.S1
1008
+ p0th = threshold.P00
1009
+ s2th = threshold.S2
1010
+ s2lth = threshold.S2L
719
1011
  else:
720
- s1th=threshold.S1.numpy()
721
- p0th=threshold.P00.numpy()
722
- s2th=threshold.S2.numpy()
723
- s2lth=threshold.S2L.numpy()
1012
+ s1th = threshold.S1.numpy()
1013
+ p0th = threshold.P00.numpy()
1014
+ s2th = threshold.S2.numpy()
1015
+ s2lth = threshold.S2L.numpy()
724
1016
  else:
725
- s1th=threshold+0*s1
726
- p0th=threshold+0*p0
727
- s2th=threshold+0*s2
728
- s2lth=threshold+0*s2l
1017
+ s1th = threshold + 0 * s1
1018
+ p0th = threshold + 0 * p0
1019
+ s2th = threshold + 0 * s2
1020
+ s2lth = threshold + 0 * s2l
729
1021
 
730
1022
  for k in range(nscale):
731
1023
  if constant:
732
- s1[:,nscale-1-k,:]=s1[:,nscale-k,:]
733
- p0[:,nscale-1-k,:]=p0[:,nscale-k,:]
1024
+ s1[:, nscale - 1 - k, :] = s1[:, nscale - k, :]
1025
+ p0[:, nscale - 1 - k, :] = p0[:, nscale - k, :]
734
1026
  else:
735
- idx=np.where((s1[:,nscale+1-k,:]>0)*(s1[:,nscale+2-k,:]>0)*(s1[:,nscale-k,:]<s1th[:,nscale-k,:]))
736
- if len(idx[0])>0:
737
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(s1[idx[0],nscale+1-k,idx[1]])-2*np.log(s1[idx[0],nscale+2-k,idx[1]]))
738
- idx=np.where((s1[:,nscale-k,:]>0)*(s1[:,nscale+1-k,:]>0)*(s1[:,nscale-1-k,:]<s1th[:,nscale-1-k,:]))
739
- if len(idx[0])>0:
740
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(s1[idx[0],nscale-k,idx[1]])-np.log(s1[idx[0],nscale+1-k,idx[1]]))
741
-
742
- idx=np.where((p0[:,nscale+1-k,:]>0)*(p0[:,nscale+2-k,:]>0)*(p0[:,nscale-k,:]<p0th[:,nscale-k,:]))
743
- if len(idx[0])>0:
744
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(p0[idx[0],nscale+1-k,idx[1]])-2*np.log(p0[idx[0],nscale+2-k,idx[1]]))
745
-
746
- idx=np.where((p0[:,nscale-k,:]>0)*(p0[:,nscale+1-k,:]>0)*(p0[:,nscale-1-k,:]<p0th[:,nscale-1-k,:]))
747
- if len(idx[0])>0:
748
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(p0[idx[0],nscale-k,idx[1]])-np.log(p0[idx[0],nscale+1-k,idx[1]]))
749
-
750
-
751
- j1,j2=self.get_j_idx()
1027
+ idx = np.where(
1028
+ (s1[:, nscale + 1 - k, :] > 0)
1029
+ * (s1[:, nscale + 2 - k, :] > 0)
1030
+ * (s1[:, nscale - k, :] < s1th[:, nscale - k, :])
1031
+ )
1032
+ if len(idx[0]) > 0:
1033
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1034
+ 3 * np.log(s1[idx[0], nscale + 1 - k, idx[1]])
1035
+ - 2 * np.log(s1[idx[0], nscale + 2 - k, idx[1]])
1036
+ )
1037
+ idx = np.where(
1038
+ (s1[:, nscale - k, :] > 0)
1039
+ * (s1[:, nscale + 1 - k, :] > 0)
1040
+ * (s1[:, nscale - 1 - k, :] < s1th[:, nscale - 1 - k, :])
1041
+ )
1042
+ if len(idx[0]) > 0:
1043
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1044
+ 2 * np.log(s1[idx[0], nscale - k, idx[1]])
1045
+ - np.log(s1[idx[0], nscale + 1 - k, idx[1]])
1046
+ )
1047
+
1048
+ idx = np.where(
1049
+ (p0[:, nscale + 1 - k, :] > 0)
1050
+ * (p0[:, nscale + 2 - k, :] > 0)
1051
+ * (p0[:, nscale - k, :] < p0th[:, nscale - k, :])
1052
+ )
1053
+ if len(idx[0]) > 0:
1054
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1055
+ 3 * np.log(p0[idx[0], nscale + 1 - k, idx[1]])
1056
+ - 2 * np.log(p0[idx[0], nscale + 2 - k, idx[1]])
1057
+ )
1058
+
1059
+ idx = np.where(
1060
+ (p0[:, nscale - k, :] > 0)
1061
+ * (p0[:, nscale + 1 - k, :] > 0)
1062
+ * (p0[:, nscale - 1 - k, :] < p0th[:, nscale - 1 - k, :])
1063
+ )
1064
+ if len(idx[0]) > 0:
1065
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1066
+ 2 * np.log(p0[idx[0], nscale - k, idx[1]])
1067
+ - np.log(p0[idx[0], nscale + 1 - k, idx[1]])
1068
+ )
1069
+
1070
+ j1, j2 = self.get_j_idx()
752
1071
 
753
1072
  for k in range(nscale):
754
1073
 
@@ -761,646 +1080,951 @@ class scat:
761
1080
  s2l[:,i0]=np.exp(2*np.log(s2l[:,i1])-np.log(s2l[:,i2]))
762
1081
  """
763
1082
 
764
- for l in range(nscale-k):
765
- i0=np.where((j1==nscale-1-k-l)*(j2==nscale-1-k))[0]
766
- i1=np.where((j1==nscale-1-k-l)*(j2==nscale -k))[0]
767
- i2=np.where((j1==nscale-1-k-l)*(j2==nscale+1-k))[0]
768
- i3=np.where((j1==nscale-1-k-l)*(j2==nscale+2-k))[0]
769
-
1083
+ for l_scale in range(nscale - k):
1084
+ i0 = np.where(
1085
+ (j1 == nscale - 1 - k - l_scale) * (j2 == nscale - 1 - k)
1086
+ )[0]
1087
+ i1 = np.where((j1 == nscale - 1 - k - l_scale) * (j2 == nscale - k))[0]
1088
+ i2 = np.where(
1089
+ (j1 == nscale - 1 - k - l_scale) * (j2 == nscale + 1 - k)
1090
+ )[0]
1091
+ i3 = np.where(
1092
+ (j1 == nscale - 1 - k - l_scale) * (j2 == nscale + 2 - k)
1093
+ )[0]
1094
+
770
1095
  if constant:
771
- s2[:,i0]=s2[:,i1]
772
- s2l[:,i0]=s2l[:,i1]
1096
+ s2[:, i0] = s2[:, i1]
1097
+ s2l[:, i0] = s2l[:, i1]
773
1098
  else:
774
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
775
- if len(idx[0])>0:
776
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
777
-
778
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
779
- if len(idx[0])>0:
780
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
781
-
782
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
783
- if len(idx[0])>0:
784
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
785
-
786
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
787
- if len(idx[0])>0:
788
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
789
-
1099
+ idx = np.where(
1100
+ (s2[:, i2] > 0) * (s2[:, i3] > 0) * (s2[:, i2] < s2th[:, i2])
1101
+ )
1102
+ if len(idx[0]) > 0:
1103
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1104
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
1105
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
1106
+ )
1107
+
1108
+ idx = np.where(
1109
+ (s2[:, i1] > 0) * (s2[:, i2] > 0) * (s2[:, i1] < s2th[:, i1])
1110
+ )
1111
+ if len(idx[0]) > 0:
1112
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1113
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1114
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1115
+ )
1116
+
1117
+ idx = np.where(
1118
+ (s2l[:, i2] > 0)
1119
+ * (s2l[:, i3] > 0)
1120
+ * (s2l[:, i2] < s2lth[:, i2])
1121
+ )
1122
+ if len(idx[0]) > 0:
1123
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1124
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1125
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1126
+ )
1127
+
1128
+ idx = np.where(
1129
+ (s2l[:, i1] > 0)
1130
+ * (s2l[:, i2] > 0)
1131
+ * (s2l[:, i1] < s2lth[:, i1])
1132
+ )
1133
+ if len(idx[0]) > 0:
1134
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1135
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1136
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1137
+ )
1138
+
790
1139
  if extend:
791
1140
  for k in range(nscale):
792
- for l in range(1,nscale):
793
- i0=np.where((j1==2*nscale-1-k)*(j2==2*nscale-1-k-l))[0]
794
- i1=np.where((j1==2*nscale-1-k)*(j2==2*nscale -k-l))[0]
795
- i2=np.where((j1==2*nscale-1-k)*(j2==2*nscale+1-k-l))[0]
796
- i3=np.where((j1==2*nscale-1-k)*(j2==2*nscale+2-k-l))[0]
1141
+ for l_scale in range(1, nscale):
1142
+ i0 = np.where(
1143
+ (j1 == 2 * nscale - 1 - k)
1144
+ * (j2 == 2 * nscale - 1 - k - l_scale)
1145
+ )[0]
1146
+ i1 = np.where(
1147
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale - k - l_scale)
1148
+ )[0]
1149
+ i2 = np.where(
1150
+ (j1 == 2 * nscale - 1 - k)
1151
+ * (j2 == 2 * nscale + 1 - k - l_scale)
1152
+ )[0]
1153
+ i3 = np.where(
1154
+ (j1 == 2 * nscale - 1 - k)
1155
+ * (j2 == 2 * nscale + 2 - k - l_scale)
1156
+ )[0]
797
1157
  if constant:
798
- s2[:,i0]=s2[:,i1]
799
- s2l[:,i0]=s2l[:,i1]
1158
+ s2[:, i0] = s2[:, i1]
1159
+ s2l[:, i0] = s2l[:, i1]
800
1160
  else:
801
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
802
- if len(idx[0])>0:
803
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
804
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
805
- if len(idx[0])>0:
806
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
807
-
808
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
809
- if len(idx[0])>0:
810
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
811
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
812
- if len(idx[0])>0:
813
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
814
-
815
- s1[np.isnan(s1)]=0.0
816
- p0[np.isnan(p0)]=0.0
817
- s2[np.isnan(s2)]=0.0
818
- s2l[np.isnan(s2l)]=0.0
819
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
820
-
821
- return scat(self.backend.constant(p0),self.S0,
822
- self.backend.constant(s1),
823
- self.backend.constant(s2),
824
- self.backend.constant(s2l),self.j1,self.j2,backend=self.backend)
1161
+ idx = np.where(
1162
+ (s2[:, i2] > 0)
1163
+ * (s2[:, i3] > 0)
1164
+ * (s2[:, i2] < s2th[:, i2])
1165
+ )
1166
+ if len(idx[0]) > 0:
1167
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1168
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
1169
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
1170
+ )
1171
+ idx = np.where(
1172
+ (s2[:, i1] > 0)
1173
+ * (s2[:, i2] > 0)
1174
+ * (s2[:, i1] < s2th[:, i1])
1175
+ )
1176
+ if len(idx[0]) > 0:
1177
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1178
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1179
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1180
+ )
1181
+
1182
+ idx = np.where(
1183
+ (s2l[:, i2] > 0)
1184
+ * (s2l[:, i3] > 0)
1185
+ * (s2l[:, i2] < s2lth[:, i2])
1186
+ )
1187
+ if len(idx[0]) > 0:
1188
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1189
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1190
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1191
+ )
1192
+ idx = np.where(
1193
+ (s2l[:, i1] > 0)
1194
+ * (s2l[:, i2] > 0)
1195
+ * (s2l[:, i1] < s2lth[:, i1])
1196
+ )
1197
+ if len(idx[0]) > 0:
1198
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1199
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1200
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1201
+ )
1202
+
1203
+ s1[np.isnan(s1)] = 0.0
1204
+ p0[np.isnan(p0)] = 0.0
1205
+ s2[np.isnan(s2)] = 0.0
1206
+ s2l[np.isnan(s2l)] = 0.0
1207
+ print(s1.sum(), p0.sum(), s2.sum(), s2l.sum())
1208
+
1209
+ return scat(
1210
+ self.backend.constant(p0),
1211
+ self.S0,
1212
+ self.backend.constant(s1),
1213
+ self.backend.constant(s2),
1214
+ self.backend.constant(s2l),
1215
+ self.j1,
1216
+ self.j2,
1217
+ backend=self.backend,
1218
+ )
825
1219
 
826
1220
  # ---------------------------------------------−---------
827
1221
  def flatten(self):
828
- if isinstance(self.S1,np.ndarray):
829
- return np.concatenate([self.S0.flatten(),
830
- self.S1.flatten(),
831
- self.P00.flatten(),
832
- self.S2.flatten(),
833
- self.S2L.flatten()],0)
1222
+ if isinstance(self.S1, np.ndarray):
1223
+ return np.concatenate(
1224
+ [
1225
+ self.S0.flatten(),
1226
+ self.S1.flatten(),
1227
+ self.P00.flatten(),
1228
+ self.S2.flatten(),
1229
+ self.S2L.flatten(),
1230
+ ],
1231
+ 0,
1232
+ )
834
1233
  else:
835
- return self.backend.bk_concat([self.backend.bk_flattenR(self.S0),
836
- self.backend.bk_flattenR(self.S1),
837
- self.backend.bk_flattenR(self.P00),
838
- self.backend.bk_flattenR(self.S2),
839
- self.backend.bk_flattenR(self.S2)],axis=0)
1234
+ return self.backend.bk_concat(
1235
+ [
1236
+ self.backend.bk_flattenR(self.S0),
1237
+ self.backend.bk_flattenR(self.S1),
1238
+ self.backend.bk_flattenR(self.P00),
1239
+ self.backend.bk_flattenR(self.S2),
1240
+ self.backend.bk_flattenR(self.S2),
1241
+ ],
1242
+ axis=0,
1243
+ )
840
1244
 
841
1245
  # ---------------------------------------------−---------
842
1246
  def flattenMask(self):
843
- if isinstance(self.S1,np.ndarray):
844
- tmp=np.expand_dims(np.concatenate([self.S1[0].flatten(),
845
- self.P00[0].flatten(),
846
- self.S2[0].flatten(),
847
- self.S2L[0].flatten()],0),0)
848
- for k in range(1,self.P00.shape[0]):
849
- tmp=np.concatenate([tmp,np.expand_dims(np.concatenate([self.S1[k].flatten(),
850
- self.P00[k].flatten(),
851
- self.S2[k].flatten(),
852
- self.S2L[k].flatten()],0),0)],0)
853
-
854
-
855
- return np.concatenate([tmp,np.expand_dims(self.S0,1)],1)
1247
+ if isinstance(self.S1, np.ndarray):
1248
+ tmp = np.expand_dims(
1249
+ np.concatenate(
1250
+ [
1251
+ self.S1[0].flatten(),
1252
+ self.P00[0].flatten(),
1253
+ self.S2[0].flatten(),
1254
+ self.S2L[0].flatten(),
1255
+ ],
1256
+ 0,
1257
+ ),
1258
+ 0,
1259
+ )
1260
+ for k in range(1, self.P00.shape[0]):
1261
+ tmp = np.concatenate(
1262
+ [
1263
+ tmp,
1264
+ np.expand_dims(
1265
+ np.concatenate(
1266
+ [
1267
+ self.S1[k].flatten(),
1268
+ self.P00[k].flatten(),
1269
+ self.S2[k].flatten(),
1270
+ self.S2L[k].flatten(),
1271
+ ],
1272
+ 0,
1273
+ ),
1274
+ 0,
1275
+ ),
1276
+ ],
1277
+ 0,
1278
+ )
1279
+
1280
+ return np.concatenate([tmp, np.expand_dims(self.S0, 1)], 1)
856
1281
  else:
857
- tmp=self.backend.bk_expand_dims(self.backend.bk_concat([self.backend.bk_flattenR(self.S1[0]),
858
- self.backend.bk_flattenR(self.P00[0]),
859
- self.backend.bk_flattenR(self.S2[0]),
860
- self.backend.bk_flattenR(self.S2[0])],axis=0),0)
861
- for k in range(1,self.P00.shape[0]):
862
- ltmp=self.backend.bk_expand_dims(self.backend.bk_concat([self.backend.bk_flattenR(self.S1[k]),
863
- self.backend.bk_flattenR(self.P00[k]),
864
- self.backend.bk_flattenR(self.S2[k]),
865
- self.backend.bk_flattenR(self.S2[k])],axis=0),0)
866
- tmp=self.backend.bk_concat([tmp,ltmp],0)
867
-
868
- return self.backend.bk_concat([tmp,self.backend.bk_expand_dims(self.S0,1)],1)
869
-
1282
+ tmp = self.backend.bk_expand_dims(
1283
+ self.backend.bk_concat(
1284
+ [
1285
+ self.backend.bk_flattenR(self.S1[0]),
1286
+ self.backend.bk_flattenR(self.P00[0]),
1287
+ self.backend.bk_flattenR(self.S2[0]),
1288
+ self.backend.bk_flattenR(self.S2[0]),
1289
+ ],
1290
+ axis=0,
1291
+ ),
1292
+ 0,
1293
+ )
1294
+ for k in range(1, self.P00.shape[0]):
1295
+ ltmp = self.backend.bk_expand_dims(
1296
+ self.backend.bk_concat(
1297
+ [
1298
+ self.backend.bk_flattenR(self.S1[k]),
1299
+ self.backend.bk_flattenR(self.P00[k]),
1300
+ self.backend.bk_flattenR(self.S2[k]),
1301
+ self.backend.bk_flattenR(self.S2[k]),
1302
+ ],
1303
+ axis=0,
1304
+ ),
1305
+ 0,
1306
+ )
1307
+ tmp = self.backend.bk_concat([tmp, ltmp], 0)
1308
+
1309
+ return self.backend.bk_concat(
1310
+ [tmp, self.backend.bk_expand_dims(self.S0, 1)], 1
1311
+ )
1312
+
870
1313
  # ---------------------------------------------−---------
871
- def model(self,i__y,add=0,dx=3,dell=2,weigth=None,inverse=False):
1314
+ def model(self, i__y, add=0, dx=3, dell=2, weigth=None, inverse=False):
872
1315
 
873
- if i__y.shape[0]<dx+1:
874
- l__dx=i__y.shape[0]-1
1316
+ if i__y.shape[0] < dx + 1:
1317
+ l__dx = i__y.shape[0] - 1
875
1318
  else:
876
- l__dx=dx
1319
+ l__dx = dx
877
1320
 
878
- if i__y.shape[0]<dell:
879
- l__dell=0
1321
+ if i__y.shape[0] < dell:
1322
+ l__dell = 0
880
1323
  else:
881
- l__dell=dell
1324
+ l__dell = dell
882
1325
 
883
- if l__dx<2:
884
- res=np.zeros([i__y.shape[0]+add])
1326
+ if l__dx < 2:
1327
+ res = np.zeros([i__y.shape[0] + add])
885
1328
  if inverse:
886
- res[:-add]=i__y
1329
+ res[:-add] = i__y
887
1330
  else:
888
- res[add:]=i__y[0:]
1331
+ res[add:] = i__y[0:]
889
1332
  return res
890
1333
 
891
1334
  if weigth is None:
892
- w=2**(np.arange(l__dx))
1335
+ w = 2 ** (np.arange(l__dx))
893
1336
  else:
894
1337
  if not inverse:
895
- w=weigth[0:l__dx]
1338
+ w = weigth[0:l__dx]
896
1339
  else:
897
- w=weigth[-l__dx:]
1340
+ w = weigth[-l__dx:]
898
1341
 
899
- x=np.arange(l__dx)+1
1342
+ x = np.arange(l__dx) + 1
900
1343
  if not inverse:
901
- y=np.log(i__y[1:l__dx+1])
1344
+ y = np.log(i__y[1 : l__dx + 1])
902
1345
  else:
903
- y=np.log(i__y[-(l__dx+1):-1])
1346
+ y = np.log(i__y[-(l__dx + 1) : -1])
904
1347
 
905
- r=np.polyfit(x,y,1,w=w)
1348
+ r = np.polyfit(x, y, 1, w=w)
906
1349
 
907
1350
  if inverse:
908
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-1)+r[1])
909
- res[:-(l__dell+add)]=i__y[:-l__dell]
1351
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - 1) + r[1])
1352
+ res[: -(l__dell + add)] = i__y[:-l__dell]
910
1353
  else:
911
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-add)+r[1])
912
- res[l__dell+add:]=i__y[l__dell:]
1354
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - add) + r[1])
1355
+ res[l__dell + add :] = i__y[l__dell:]
913
1356
  return res
914
1357
 
915
- def findn(self,n):
916
- d=np.sqrt(1+8*n)
917
- return int((d-1)/2)
1358
+ def findn(self, n):
1359
+ d = np.sqrt(1 + 8 * n)
1360
+ return int((d - 1) / 2)
918
1361
 
919
- def findidx(self,s2):
920
- i1=np.zeros([s2.shape[1]],dtype='int')
921
- i2=np.zeros([s2.shape[1]],dtype='int')
922
- n=0
1362
+ def findidx(self, s2):
1363
+ i1 = np.zeros([s2.shape[1]], dtype="int")
1364
+ i2 = np.zeros([s2.shape[1]], dtype="int")
1365
+ n = 0
923
1366
  for k in range(self.findn(s2.shape[1])):
924
- i1[n:n+k+1]=np.arange(k+1)
925
- i2[n:n+k+1]=k
926
- n=n+k+1
927
- return i1,i2
928
-
929
- def extrapol_s2(self,add,lnorm=1):
930
- if lnorm==1:
931
- s2=self.S2.numpy()
932
- if lnorm==2:
933
- s2=self.S2L.numpy()
934
- i1,i2=self.findidx(s2)
935
-
936
- so2=np.zeros([s2.shape[0],(self.findn(s2.shape[1])+add)*(self.findn(s2.shape[1])+add+1)//2,s2.shape[2],s2.shape[3]])
937
- oi1,oi2=self.findidx(so2)
938
- for l in range(s2.shape[0]):
1367
+ i1[n : n + k + 1] = np.arange(k + 1)
1368
+ i2[n : n + k + 1] = k
1369
+ n = n + k + 1
1370
+ return i1, i2
1371
+
1372
+ def extrapol_s2(self, add, lnorm=1):
1373
+ if lnorm == 1:
1374
+ s2 = self.S2.numpy()
1375
+ if lnorm == 2:
1376
+ s2 = self.S2L.numpy()
1377
+ i1, i2 = self.findidx(s2)
1378
+
1379
+ so2 = np.zeros(
1380
+ [
1381
+ s2.shape[0],
1382
+ (self.findn(s2.shape[1]) + add)
1383
+ * (self.findn(s2.shape[1]) + add + 1)
1384
+ // 2,
1385
+ s2.shape[2],
1386
+ s2.shape[3],
1387
+ ]
1388
+ )
1389
+ oi1, oi2 = self.findidx(so2)
1390
+ for l_batch in range(s2.shape[0]):
939
1391
  for k in range(self.findn(s2.shape[1])):
940
1392
  for i in range(s2.shape[2]):
941
1393
  for j in range(s2.shape[3]):
942
- tmp=self.model(s2[l,i2==k,i,j],dx=4,dell=1,add=add,weigth=np.array([1,2,2,2]))
943
- tmp[np.isnan(tmp)]=0.0
944
- so2[l,oi2==k+add,i,j]=tmp
945
-
946
-
947
- for l in range(s2.shape[0]):
948
- for k in range(add+1,-1,-1):
949
- lidx=np.where(oi2-oi1==k)[0]
950
- lidx2=np.where(oi2-oi1==k+1)[0]
1394
+ tmp = self.model(
1395
+ s2[l_batch, i2 == k, i, j],
1396
+ dx=4,
1397
+ dell=1,
1398
+ add=add,
1399
+ weigth=np.array([1, 2, 2, 2]),
1400
+ )
1401
+ tmp[np.isnan(tmp)] = 0.0
1402
+ so2[l_batch, oi2 == k + add, i, j] = tmp
1403
+
1404
+ for l_batch in range(s2.shape[0]):
1405
+ for k in range(add + 1, -1, -1):
1406
+ lidx = np.where(oi2 - oi1 == k)[0]
1407
+ lidx2 = np.where(oi2 - oi1 == k + 1)[0]
951
1408
  for i in range(s2.shape[2]):
952
1409
  for j in range(s2.shape[3]):
953
- so2[l,lidx[0:add+2-k],i,j]=so2[l,lidx2[0:add+2-k],i,j]
1410
+ so2[l_batch, lidx[0 : add + 2 - k], i, j] = so2[
1411
+ l_batch, lidx2[0 : add + 2 - k], i, j
1412
+ ]
954
1413
 
955
- return(so2)
1414
+ return so2
956
1415
 
957
- def extrapol_s1(self,i_s1,add):
958
- s1=i_s1.numpy()
959
- so1=np.zeros([s1.shape[0],s1.shape[1]+add,s1.shape[2]])
1416
+ def extrapol_s1(self, i_s1, add):
1417
+ s1 = i_s1.numpy()
1418
+ so1 = np.zeros([s1.shape[0], s1.shape[1] + add, s1.shape[2]])
960
1419
  for k in range(s1.shape[0]):
961
1420
  for i in range(s1.shape[2]):
962
- so1[k,:,i]=self.model(s1[k,:,i],dx=4,dell=1,add=add)
963
- so1[k,np.isnan(so1[k,:,i]),i]=0.0
1421
+ so1[k, :, i] = self.model(s1[k, :, i], dx=4, dell=1, add=add)
1422
+ so1[k, np.isnan(so1[k, :, i]), i] = 0.0
964
1423
  return so1
965
1424
 
966
- def extrapol(self,add):
967
- return scat(self.extrapol_s1(self.P00,add), \
968
- self.S0, \
969
- self.extrapol_s1(self.S1,add), \
970
- self.extrapol_s2(add,lnorm=1), \
971
- self.extrapol_s2(add,lnorm=2),self.j1,self.j2,backend=self.backend)
972
-
973
-
974
-
975
-
976
-
1425
+ def extrapol(self, add):
1426
+ return scat(
1427
+ self.extrapol_s1(self.P00, add),
1428
+ self.S0,
1429
+ self.extrapol_s1(self.S1, add),
1430
+ self.extrapol_s2(add, lnorm=1),
1431
+ self.extrapol_s2(add, lnorm=2),
1432
+ self.j1,
1433
+ self.j2,
1434
+ backend=self.backend,
1435
+ )
1436
+
1437
+
977
1438
  class funct(FOC.FoCUS):
978
-
979
- def fill(self,im,nullval=hp.UNSEEN):
980
- return self.fill_healpy(im,nullval=nullval)
981
-
982
- def moments(self,list_scat):
983
- S0=None
1439
+
1440
+ def fill(self, im, nullval=hp.UNSEEN):
1441
+ return self.fill_healpy(im, nullval=nullval)
1442
+
1443
+ def moments(self, list_scat):
1444
+ S0 = None
984
1445
  for k in list_scat:
985
- tmp=list_scat[k]
986
- nS0=np.expand_dims(tmp.S0.numpy(),0)
987
- nP00=np.expand_dims(tmp.P00.numpy(),0)
988
- nS1=np.expand_dims(tmp.S1.numpy(),0)
989
- nS2=np.expand_dims(tmp.S2.numpy(),0)
990
- nS2L=np.expand_dims(tmp.S2L.numpy(),0)
991
-
1446
+ tmp = list_scat[k]
1447
+ nS0 = np.expand_dims(tmp.S0.numpy(), 0)
1448
+ nP00 = np.expand_dims(tmp.P00.numpy(), 0)
1449
+ nS1 = np.expand_dims(tmp.S1.numpy(), 0)
1450
+ nS2 = np.expand_dims(tmp.S2.numpy(), 0)
1451
+ nS2L = np.expand_dims(tmp.S2L.numpy(), 0)
1452
+
992
1453
  if S0 is None:
993
- S0=nS0
994
- P00=nP00
995
- S1=nS1
996
- S2=nS2
997
- S2L=nS2L
1454
+ S0 = nS0
1455
+ P00 = nP00
1456
+ S1 = nS1
1457
+ S2 = nS2
1458
+ S2L = nS2L
998
1459
  else:
999
- S0=np.concatenate([S0,nS0],0)
1000
- P00=np.concatenate([P00,nP00],0)
1001
- S1=np.concatenate([S1,nS1],0)
1002
- S2=np.concatenate([S2,nS2],0)
1003
- S2L=np.concatenate([S2L,nS2L],0)
1004
-
1005
- sS0=np.std(S0,0)
1006
- sP00=np.std(P00,0)
1007
- sS1=np.std(S1,0)
1008
- sS2=np.std(S2,0)
1009
- sS2L=np.std(S2L,0)
1010
-
1011
- mS0=np.mean(S0,0)
1012
- mP00=np.mean(P00,0)
1013
- mS1=np.mean(S1,0)
1014
- mS2=np.mean(S2,0)
1015
- mS2L=np.mean(S2L,0)
1016
-
1017
- return scat(mP00,mS0,mS1,mS2,mS2L,tmp.j1,tmp.j2,backend=self.backend), \
1018
- scat(sP00,sS0,sS1,sS2,sS2L,tmp.j1,tmp.j2,backend=self.backend)
1019
-
1020
- def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,calc_var=False):
1460
+ S0 = np.concatenate([S0, nS0], 0)
1461
+ P00 = np.concatenate([P00, nP00], 0)
1462
+ S1 = np.concatenate([S1, nS1], 0)
1463
+ S2 = np.concatenate([S2, nS2], 0)
1464
+ S2L = np.concatenate([S2L, nS2L], 0)
1465
+
1466
+ sS0 = np.std(S0, 0)
1467
+ sP00 = np.std(P00, 0)
1468
+ sS1 = np.std(S1, 0)
1469
+ sS2 = np.std(S2, 0)
1470
+ sS2L = np.std(S2L, 0)
1471
+
1472
+ mS0 = np.mean(S0, 0)
1473
+ mP00 = np.mean(P00, 0)
1474
+ mS1 = np.mean(S1, 0)
1475
+ mS2 = np.mean(S2, 0)
1476
+ mS2L = np.mean(S2L, 0)
1477
+
1478
+ return scat(
1479
+ mP00, mS0, mS1, mS2, mS2L, tmp.j1, tmp.j2, backend=self.backend
1480
+ ), scat(sP00, sS0, sS1, sS2, sS2L, tmp.j1, tmp.j2, backend=self.backend)
1481
+
1482
+ def eval(
1483
+ self,
1484
+ image1,
1485
+ image2=None,
1486
+ mask=None,
1487
+ Auto=True,
1488
+ s0_off=1e-6,
1489
+ calc_var=False,
1490
+ norm=None,
1491
+ ):
1021
1492
  # Check input consistency
1022
1493
  if image2 is not None:
1023
- if list(image1.shape)!=list(image2.shape):
1024
- print('The two input image should have the same size to eval Scattering')
1025
-
1026
- exit(0)
1494
+ if list(image1.shape) != list(image2.shape):
1495
+ print(
1496
+ "The two input image should have the same size to eval Scattering"
1497
+ )
1498
+
1499
+ return None
1027
1500
  if mask is not None:
1028
- if list(image1.shape)!=list(mask.shape)[1:]:
1029
- print('The mask should have the same size than the input image to eval Scattering')
1030
- print(image1.shape,mask.shape)
1031
- exit(0)
1032
- if self.use_2D and len(image1.shape)<2:
1033
- print('To work with 2D scattering transform, two dimension is needed, input map has only on dimension')
1034
- exit(0)
1035
-
1036
-
1501
+ if list(image1.shape) != list(mask.shape)[1:]:
1502
+ print(
1503
+ "The mask should have the same size than the input image to eval Scattering"
1504
+ )
1505
+ print("Image shape ", image1.shape, "Mask shape ", mask.shape)
1506
+ return None
1507
+ if self.use_2D and len(image1.shape) < 2:
1508
+ print(
1509
+ "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
1510
+ )
1511
+ return None
1512
+
1037
1513
  ### AUTO OR CROSS
1038
1514
  cross = False
1039
1515
  if image2 is not None:
1040
1516
  cross = True
1041
- all_cross=not Auto
1042
- else:
1043
- all_cross=False
1044
-
1517
+
1045
1518
  # Check if image1 is [Npix] or [Nbatch,Npix]
1046
- axis=1
1047
-
1519
+ axis = 1
1520
+
1048
1521
  # determine jmax and nside corresponding to the input map
1049
1522
  im_shape = image1.shape
1050
1523
  if self.use_2D:
1051
- if len(image1.shape)==2:
1052
- nside=np.min([im_shape[0],im_shape[1]])
1053
- npix = im_shape[0]*im_shape[1] # Number of pixels
1054
- x1=im_shape[0]
1055
- x2=im_shape[1]
1524
+ if len(image1.shape) == 2:
1525
+ nside = np.min([im_shape[0], im_shape[1]])
1526
+ npix = im_shape[0] * im_shape[1] # Number of pixels
1056
1527
  else:
1057
- nside=np.min([im_shape[1],im_shape[2]])
1058
- npix = im_shape[1]*im_shape[2] # Number of pixels
1059
- x1=im_shape[1]
1060
- x2=im_shape[2]
1061
- jmax = int(np.log(nside-self.KERNELSZ) / np.log(2)) # Number of j scales
1528
+ nside = np.min([im_shape[1], im_shape[2]])
1529
+ npix = im_shape[1] * im_shape[2] # Number of pixels
1530
+ jmax = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
1062
1531
  else:
1063
- if len(image1.shape)==2:
1532
+ if len(image1.shape) == 2:
1064
1533
  npix = int(im_shape[1]) # Number of pixels
1065
1534
  else:
1066
1535
  npix = int(im_shape[0]) # Number of pixels
1067
1536
 
1068
- nside=int(np.sqrt(npix//12))
1069
-
1070
- jmax=int(np.log(nside)/np.log(2)) #-self.OSTEP
1537
+ nside = int(np.sqrt(npix // 12))
1538
+
1539
+ jmax = int(np.log(nside) / np.log(2)) # -self.OSTEP
1071
1540
 
1072
1541
  ### LOCAL VARIABLES (IMAGES and MASK)
1073
1542
  # Check if image1 is [Npix] or [Nbatch,Npix]
1074
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1543
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1075
1544
  # image1 is [Nbatch, Npix]
1076
- I1 = self.backend.bk_cast(self.backend.bk_expand_dims(image1,0)) # Local image1 [Nbatch, Npix]
1545
+ I1 = self.backend.bk_cast(
1546
+ self.backend.bk_expand_dims(image1, 0)
1547
+ ) # Local image1 [Nbatch, Npix]
1077
1548
  if cross:
1078
- I2 = self.backend.bk_cast(self.backend.bk_expand_dims(image2,0)) # Local image2 [Nbatch, Npix]
1549
+ I2 = self.backend.bk_cast(
1550
+ self.backend.bk_expand_dims(image2, 0)
1551
+ ) # Local image2 [Nbatch, Npix]
1079
1552
  else:
1080
- I1=self.backend.bk_cast(image1)
1553
+ I1 = self.backend.bk_cast(image1)
1081
1554
  if cross:
1082
- I2=self.backend.bk_cast(image2)
1083
-
1555
+ I2 = self.backend.bk_cast(image2)
1556
+
1084
1557
  # self.mask is [Nmask, Npix]
1085
1558
  if mask is None:
1086
1559
  if self.use_2D:
1087
- vmask = self.backend.bk_ones([1, I1.shape[axis], I1.shape[axis+1]],dtype=self.all_type)
1560
+ vmask = self.backend.bk_ones(
1561
+ [1, I1.shape[axis], I1.shape[axis + 1]], dtype=self.all_type
1562
+ )
1088
1563
  else:
1089
1564
  vmask = self.backend.bk_ones([1, I1.shape[axis]], dtype=self.all_type)
1090
1565
  else:
1091
1566
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
1092
1567
 
1093
- if self.KERNELSZ>3:
1094
- if self.KERNELSZ==5:
1568
+ if self.KERNELSZ > 3:
1569
+ if self.KERNELSZ == 5:
1095
1570
  # if the kernel size is bigger than 3 increase the binning before smoothing
1096
1571
  if self.use_2D:
1097
- l_image1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1098
- vmask=self.up_grade(vmask,I1.shape[axis]*2,axis=1,nouty=I1.shape[axis+1]*2)
1572
+ l_image1 = self.up_grade(
1573
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
1574
+ )
1575
+ vmask = self.up_grade(
1576
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
1577
+ )
1099
1578
  else:
1100
- l_image1=self.up_grade(I1,nside*2,axis=axis)
1101
- vmask=self.up_grade(vmask,nside*2,axis=1)
1102
-
1579
+ l_image1 = self.up_grade(I1, nside * 2, axis=axis)
1580
+ vmask = self.up_grade(vmask, nside * 2, axis=1)
1581
+
1103
1582
  if cross:
1104
1583
  if self.use_2D:
1105
- l_image2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1584
+ l_image2 = self.up_grade(
1585
+ I2,
1586
+ I2.shape[axis] * 2,
1587
+ axis=axis,
1588
+ nouty=I2.shape[axis + 1] * 2,
1589
+ )
1106
1590
  else:
1107
- l_image2=self.up_grade(I2,nside*2,axis=axis)
1591
+ l_image2 = self.up_grade(I2, nside * 2, axis=axis)
1108
1592
  else:
1109
1593
  # if the kernel size is bigger than 3 increase the binning before smoothing
1110
1594
  if self.use_2D:
1111
- print(axis,image1.shape)
1112
- l_image1=self.up_grade(l_image1,I1.shape[axis]*4,axis=axis,nouty=I1.shape[axis+1]*4)
1113
- vmask=self.up_grade(vmask,I1.shape[axis]*4,axis=1,nouty=I1.shape[axis+1]*4)
1595
+ l_image1 = self.up_grade(
1596
+ l_image1,
1597
+ I1.shape[axis] * 4,
1598
+ axis=axis,
1599
+ nouty=I1.shape[axis + 1] * 4,
1600
+ )
1601
+ vmask = self.up_grade(
1602
+ vmask, I1.shape[axis] * 4, axis=1, nouty=I1.shape[axis + 1] * 4
1603
+ )
1114
1604
  else:
1115
- l_image1=self.up_grade(l_image1,nside*4,axis=axis)
1116
- vmask=self.up_grade(vmask,nside*4,axis=1)
1117
-
1605
+ l_image1 = self.up_grade(l_image1, nside * 4, axis=axis)
1606
+ vmask = self.up_grade(vmask, nside * 4, axis=1)
1607
+
1118
1608
  if cross:
1119
1609
  if self.use_2D:
1120
- l_image2=self.up_grade(l_image2,I2.shape[axis]*4,axis=axis,nouty=I2.shape[axis+1]*4)
1610
+ l_image2 = self.up_grade(
1611
+ l_image2,
1612
+ I2.shape[axis] * 4,
1613
+ axis=axis,
1614
+ nouty=I2.shape[axis + 1] * 4,
1615
+ )
1121
1616
  else:
1122
- l_image2=self.up_grade(l_image2,nside*4,axis=axis)
1617
+ l_image2 = self.up_grade(l_image2, nside * 4, axis=axis)
1123
1618
  else:
1124
- l_image1=I1
1619
+ l_image1 = I1
1125
1620
  if cross:
1126
- l_image2=I2
1621
+ l_image2 = I2
1127
1622
 
1128
1623
  if calc_var:
1129
- s0,vs0 = self.masked_mean(l_image1,vmask,axis=axis,calc_var=True)
1130
- s0=s0+s0_off
1624
+ s0, vs0 = self.masked_mean(l_image1, vmask, axis=axis, calc_var=True)
1625
+ s0 = s0 + s0_off
1131
1626
  else:
1132
- s0 = self.masked_mean(l_image1,vmask,axis=axis)+s0_off
1133
-
1134
- if cross and Auto==False:
1627
+ s0 = self.masked_mean(l_image1, vmask, axis=axis) + s0_off
1628
+
1629
+ if cross and not Auto:
1135
1630
  if calc_var:
1136
- s02,vs02=self.masked_mean(l_image2,vmask,axis=axis,calc_var=True)
1631
+ s02, vs02 = self.masked_mean(l_image2, vmask, axis=axis, calc_var=True)
1137
1632
  else:
1138
- s02=self.masked_mean(l_image2,vmask,axis=axis)
1139
-
1140
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1633
+ s02 = self.masked_mean(l_image2, vmask, axis=axis)
1634
+
1635
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1141
1636
  if self.backend.bk_is_complex(s0):
1142
- s0 = self.backend.bk_complex(s0,s02+s0_off)
1637
+ s0 = self.backend.bk_complex(s0, s02 + s0_off)
1143
1638
  if calc_var:
1144
- vs0 = self.backend.bk_complex(vs0,vs02)
1639
+ vs0 = self.backend.bk_complex(vs0, vs02)
1145
1640
  else:
1146
- s0 = self.backend.bk_concat([s0,s02],axis=0)
1641
+ s0 = self.backend.bk_concat([s0, s02], axis=0)
1147
1642
  if calc_var:
1148
- vs0 = self.backend.bk_concat([vs0,vs02],axis=0)
1643
+ vs0 = self.backend.bk_concat([vs0, vs02], axis=0)
1149
1644
  else:
1150
1645
  if self.backend.bk_is_complex(s0):
1151
- s0 = self.backend.bk_complex(s0,s02+s0_off)
1646
+ s0 = self.backend.bk_complex(s0, s02 + s0_off)
1152
1647
  if calc_var:
1153
- vs0 = self.backend.bk_complex(vs0,vs02)
1648
+ vs0 = self.backend.bk_complex(vs0, vs02)
1154
1649
  else:
1155
- s0 = self.backend.bk_concat([s0,s02],axis=0)
1650
+ s0 = self.backend.bk_concat([s0, s02], axis=0)
1156
1651
  if calc_var:
1157
- vs0 = self.backend.bk_concat([vs0,vs02],axis=0)
1158
-
1159
- s1=None
1160
- s2=None
1161
- s2l=None
1162
- p00=None
1163
- s2j1=None
1164
- s2j2=None
1165
-
1166
- l2_image=None
1167
- l2_image_imag=None
1168
-
1652
+ vs0 = self.backend.bk_concat([vs0, vs02], axis=0)
1653
+
1654
+ s1 = None
1655
+ s2 = None
1656
+ s2l = None
1657
+ p00 = None
1658
+ s2j1 = None
1659
+ s2j2 = None
1660
+ l2_image = None
1169
1661
  for j1 in range(jmax):
1170
- if j1<jmax-self.OSTEP: # stop to add scales
1662
+ if j1 < jmax - self.OSTEP: # stop to add scales
1171
1663
  # Convol image along the axis defined by 'axis' using the wavelet defined at
1172
1664
  # the foscat initialisation
1173
- #c_image_real is [....,Npix_j1,....,Norient]
1174
- c_image1=self.convol(l_image1,axis=axis)
1665
+ # c_image_real is [....,Npix_j1,....,Norient]
1666
+ c_image1 = self.convol(l_image1, axis=axis)
1175
1667
  if cross:
1176
- c_image2=self.convol(l_image2,axis=axis)
1668
+ c_image2 = self.convol(l_image2, axis=axis)
1177
1669
  else:
1178
- c_image2=c_image1
1670
+ c_image2 = c_image1
1179
1671
 
1180
1672
  # Compute (a+ib)*(a+ib)* the last c_image column is the real and imaginary part
1181
- conj=c_image1*self.backend.bk_conjugate(c_image2)
1182
-
1673
+ conj = c_image1 * self.backend.bk_conjugate(c_image2)
1674
+
1183
1675
  if Auto:
1184
- conj=self.backend.bk_real(conj)
1676
+ conj = self.backend.bk_real(conj)
1185
1677
 
1186
1678
  # Compute l_p00 [....,....,Nmask,j1,Norient]
1187
1679
  if calc_var:
1188
- l_p00,l_vp00 = self.masked_mean(conj,vmask,axis=axis,rank=j1,calc_var=True)
1189
- l_p00 = self.backend.bk_expand_dims(l_p00,-2)
1190
- l_vp00 = self.backend.bk_expand_dims(l_vp00,-2)
1680
+ l_p00, l_vp00 = self.masked_mean(
1681
+ conj, vmask, axis=axis, rank=j1, calc_var=True
1682
+ )
1683
+ l_p00 = self.backend.bk_expand_dims(l_p00, -2)
1684
+ l_vp00 = self.backend.bk_expand_dims(l_vp00, -2)
1191
1685
  else:
1192
- l_p00 = self.masked_mean(conj,vmask,axis=axis,rank=j1)
1193
- l_p00 = self.backend.bk_expand_dims(l_p00,-2)
1686
+ l_p00 = self.masked_mean(conj, vmask, axis=axis, rank=j1)
1687
+ l_p00 = self.backend.bk_expand_dims(l_p00, -2)
1194
1688
 
1195
- conj=self.backend.bk_L1(conj)
1689
+ conj = self.backend.bk_L1(conj)
1196
1690
 
1197
- # Compute l_s1 [....,....,Nmask,1,Norient]
1691
+ # Compute l_s1 [....,....,Nmask,1,Norient]
1198
1692
  if calc_var:
1199
- l_s1,l_vs1 = self.masked_mean(conj,vmask,axis=axis,rank=j1,calc_var=True)
1200
- l_s1 =self.backend.bk_expand_dims(l_s1,-2)
1201
- l_vs1 =self.backend.bk_expand_dims(l_vs1,-2)
1693
+ l_s1, l_vs1 = self.masked_mean(
1694
+ conj, vmask, axis=axis, rank=j1, calc_var=True
1695
+ )
1696
+ l_s1 = self.backend.bk_expand_dims(l_s1, -2)
1697
+ l_vs1 = self.backend.bk_expand_dims(l_vs1, -2)
1202
1698
  else:
1203
- l_s1 = self.backend.bk_expand_dims(self.masked_mean(conj,vmask,axis=axis,rank=j1),-2)
1699
+ l_s1 = self.backend.bk_expand_dims(
1700
+ self.masked_mean(conj, vmask, axis=axis, rank=j1), -2
1701
+ )
1204
1702
 
1205
- # Concat S1,P00 [....,....,Nmask,j1,Norient]
1703
+ # Concat S1,P00 [....,....,Nmask,j1,Norient]
1206
1704
  if s1 is None:
1207
- s1=l_s1
1208
- p00=l_p00
1705
+ s1 = l_s1
1706
+ p00 = l_p00
1209
1707
  if calc_var:
1210
- vs1=l_vs1
1211
- vp00=l_vp00
1708
+ vs1 = l_vs1
1709
+ vp00 = l_vp00
1212
1710
  else:
1213
- s1=self.backend.bk_concat([s1,l_s1],axis=-2)
1214
- p00=self.backend.bk_concat([p00,l_p00],axis=-2)
1711
+ s1 = self.backend.bk_concat([s1, l_s1], axis=-2)
1712
+ p00 = self.backend.bk_concat([p00, l_p00], axis=-2)
1215
1713
  if calc_var:
1216
- vs1=self.backend.bk_concat([vs1,l_vs1],axis=-2)
1217
- vp00=self.backend.bk_concat([vp00,l_vp00],axis=-2)
1714
+ vs1 = self.backend.bk_concat([vs1, l_vs1], axis=-2)
1715
+ vp00 = self.backend.bk_concat([vp00, l_vp00], axis=-2)
1218
1716
 
1219
1717
  # Concat l2_image [....,j1,Npix_j1,,....,Norient]
1220
1718
  if l2_image is None:
1221
1719
  if self.use_2D:
1222
- l2_image=self.backend.bk_expand_dims(conj,axis=-4)
1720
+ l2_image = self.backend.bk_expand_dims(conj, axis=-4)
1223
1721
  else:
1224
- l2_image=self.backend.bk_expand_dims(conj,axis=-3)
1722
+ l2_image = self.backend.bk_expand_dims(conj, axis=-3)
1225
1723
  else:
1226
1724
  if self.use_2D:
1227
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=-4),l2_image],axis=-4)
1725
+ l2_image = self.backend.bk_concat(
1726
+ [self.backend.bk_expand_dims(conj, axis=-4), l2_image],
1727
+ axis=-4,
1728
+ )
1228
1729
  else:
1229
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=-3),l2_image],axis=-3)
1730
+ l2_image = self.backend.bk_concat(
1731
+ [self.backend.bk_expand_dims(conj, axis=-3), l2_image],
1732
+ axis=-3,
1733
+ )
1230
1734
 
1231
1735
  # Convol l2_image [....,Npix_j1,j1,....,Norient,Norient]
1232
- c2_image=self.convol(self.backend.bk_relu(l2_image),axis=axis+1)
1736
+ c2_image = self.convol(self.backend.bk_relu(l2_image), axis=axis + 1)
1233
1737
 
1234
- conj2p=c2_image*self.backend.bk_conjugate(c2_image)
1235
- conj2pl1=self.backend.bk_L1(conj2p)
1738
+ conj2p = c2_image * self.backend.bk_conjugate(c2_image)
1739
+ conj2pl1 = self.backend.bk_L1(conj2p)
1236
1740
 
1237
1741
  if Auto:
1238
- conj2p=self.backend.bk_real(conj2p)
1239
- conj2pl1=self.backend.bk_real(conj2pl1)
1742
+ conj2p = self.backend.bk_real(conj2p)
1743
+ conj2pl1 = self.backend.bk_real(conj2pl1)
1240
1744
 
1241
- c2_image=self.convol(self.backend.bk_relu(-l2_image),axis=axis+1)
1745
+ c2_image = self.convol(self.backend.bk_relu(-l2_image), axis=axis + 1)
1242
1746
 
1243
- conj2m=c2_image*self.backend.bk_conjugate(c2_image)
1244
- conj2ml1=self.backend.bk_L1(conj2m)
1747
+ conj2m = c2_image * self.backend.bk_conjugate(c2_image)
1748
+ conj2ml1 = self.backend.bk_L1(conj2m)
1245
1749
 
1246
1750
  if Auto:
1247
- conj2m=self.backend.bk_real(conj2m)
1248
- conj2ml1=self.backend.bk_real(conj2ml1)
1249
-
1751
+ conj2m = self.backend.bk_real(conj2m)
1752
+ conj2ml1 = self.backend.bk_real(conj2ml1)
1753
+
1250
1754
  # Convol l_s2 [....,....,Nmask,j1,Norient,Norient]
1251
1755
  if calc_var:
1252
- l_s2,l_vs2 = self.masked_mean(conj2p-conj2m,vmask,axis=axis+1,rank=j1,calc_var=True)
1253
- l_s2l1,l_vs2l1 = self.masked_mean(conj2pl1-conj2ml1,vmask,axis=axis+1,rank=j1,calc_var=True)
1756
+ l_s2, l_vs2 = self.masked_mean(
1757
+ conj2p - conj2m, vmask, axis=axis + 1, rank=j1, calc_var=True
1758
+ )
1759
+ l_s2l1, l_vs2l1 = self.masked_mean(
1760
+ conj2pl1 - conj2ml1, vmask, axis=axis + 1, rank=j1, calc_var=True
1761
+ )
1254
1762
  else:
1255
- l_s2 = self.masked_mean(conj2p-conj2m,vmask,axis=axis+1,rank=j1)
1256
- l_s2l1 = self.masked_mean(conj2pl1-conj2ml1,vmask,axis=axis+1,rank=j1)
1763
+ l_s2 = self.masked_mean(conj2p - conj2m, vmask, axis=axis + 1, rank=j1)
1764
+ l_s2l1 = self.masked_mean(
1765
+ conj2pl1 - conj2ml1, vmask, axis=axis + 1, rank=j1
1766
+ )
1257
1767
 
1258
1768
  # Concat l_s2 [....,....,Nmask,j1*(j1+1)/2,Norient,Norient]
1259
1769
  if s2 is None:
1260
- s2l=l_s2
1261
- s2=l_s2l1
1770
+ s2l = l_s2
1771
+ s2 = l_s2l1
1262
1772
  if calc_var:
1263
- vs2l=l_vs2
1264
- vs2=l_vs2l1
1265
-
1266
- s2j1=np.arange(l_s2.shape[axis+1],dtype='int')
1267
- s2j2=j1*np.ones(l_s2.shape[axis+1],dtype='int')
1773
+ vs2l = l_vs2
1774
+ vs2 = l_vs2l1
1775
+
1776
+ s2j1 = np.arange(l_s2.shape[axis + 1], dtype="int")
1777
+ s2j2 = j1 * np.ones(l_s2.shape[axis + 1], dtype="int")
1268
1778
  else:
1269
- s2=self.backend.bk_concat([s2,l_s2l1],axis=-3)
1270
- s2l=self.backend.bk_concat([s2l,l_s2],axis=-3)
1779
+ s2 = self.backend.bk_concat([s2, l_s2l1], axis=-3)
1780
+ s2l = self.backend.bk_concat([s2l, l_s2], axis=-3)
1271
1781
  if calc_var:
1272
- vs2=self.backend.bk_concat([vs2,l_vs2l1],axis=-3)
1273
- vs2l=self.backend.bk_concat([vs2l,l_vs2],axis=-3)
1274
-
1275
- s2j1=np.concatenate([s2j1,np.arange(l_s2.shape[axis+1],dtype='int')],0)
1276
- s2j2=np.concatenate([s2j2,j1*np.ones(l_s2.shape[axis+1],dtype='int')],0)
1277
-
1278
- if j1!=jmax-1:
1279
- # Rescale vmask [Nmask,Npix_j1//4]
1280
- vmask = self.smooth(vmask,axis=1)
1281
- vmask = self.ud_grade_2(vmask,axis=1)
1782
+ vs2 = self.backend.bk_concat([vs2, l_vs2l1], axis=-3)
1783
+ vs2l = self.backend.bk_concat([vs2l, l_vs2], axis=-3)
1784
+
1785
+ s2j1 = np.concatenate(
1786
+ [s2j1, np.arange(l_s2.shape[axis + 1], dtype="int")], 0
1787
+ )
1788
+ s2j2 = np.concatenate(
1789
+ [s2j2, j1 * np.ones(l_s2.shape[axis + 1], dtype="int")], 0
1790
+ )
1791
+
1792
+ if j1 != jmax - 1:
1793
+ # Rescale vmask [Nmask,Npix_j1//4]
1794
+ vmask = self.smooth(vmask, axis=1)
1795
+ vmask = self.ud_grade_2(vmask, axis=1)
1282
1796
  if self.mask_thres is not None:
1283
- vmask = self.backend.bk_threshold(vmask,self.mask_thres)
1797
+ vmask = self.backend.bk_threshold(vmask, self.mask_thres)
1284
1798
 
1285
- # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1286
- l2_image = self.smooth(l2_image,axis=axis+1)
1287
- l2_image = self.ud_grade_2(l2_image,axis=axis+1)
1799
+ # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1800
+ l2_image = self.smooth(l2_image, axis=axis + 1)
1801
+ l2_image = self.ud_grade_2(l2_image, axis=axis + 1)
1288
1802
 
1289
- # Rescale l_image [....,Npix_j1//4,....]
1290
- l_image1 = self.smooth(l_image1,axis=axis)
1291
- l_image1 = self.ud_grade_2(l_image1,axis=axis)
1803
+ # Rescale l_image [....,Npix_j1//4,....]
1804
+ l_image1 = self.smooth(l_image1, axis=axis)
1805
+ l_image1 = self.ud_grade_2(l_image1, axis=axis)
1292
1806
  if cross:
1293
- l_image2 = self.smooth(l_image2,axis=axis)
1294
- l_image2 = self.ud_grade_2(l_image2,axis=axis)
1295
-
1296
-
1297
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1298
- sc_ret=scat(p00[0],s0[0],s1[0],s2[0],s2l[0],s2j1,s2j2,cross=cross,backend=self.backend)
1807
+ l_image2 = self.smooth(l_image2, axis=axis)
1808
+ l_image2 = self.ud_grade_2(l_image2, axis=axis)
1809
+
1810
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1811
+ sc_ret = scat(
1812
+ p00[0],
1813
+ s0[0],
1814
+ s1[0],
1815
+ s2[0],
1816
+ s2l[0],
1817
+ s2j1,
1818
+ s2j2,
1819
+ cross=cross,
1820
+ backend=self.backend,
1821
+ )
1299
1822
  else:
1300
- sc_ret=scat(p00,s0,s1,s2,s2l,s2j1,s2j2,cross=cross,backend=self.backend)
1301
-
1823
+ sc_ret = scat(
1824
+ p00, s0, s1, s2, s2l, s2j1, s2j2, cross=cross, backend=self.backend
1825
+ )
1826
+
1302
1827
  if calc_var:
1303
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1304
- vsc_ret=scat(vp00[0],vs0[0],vs1[0],vs2[0],vs2l[0],s2j1,s2j2,cross=cross,backend=self.backend)
1828
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1829
+ vsc_ret = scat(
1830
+ vp00[0],
1831
+ vs0[0],
1832
+ vs1[0],
1833
+ vs2[0],
1834
+ vs2l[0],
1835
+ s2j1,
1836
+ s2j2,
1837
+ cross=cross,
1838
+ backend=self.backend,
1839
+ )
1305
1840
  else:
1306
- vsc_ret=scat(vp00,vs0,vs1,vs2,vs2l,s2j1,s2j2,cross=cross,backend=self.backend)
1307
- return sc_ret,vsc_ret
1841
+ vsc_ret = scat(
1842
+ vp00,
1843
+ vs0,
1844
+ vs1,
1845
+ vs2,
1846
+ vs2l,
1847
+ s2j1,
1848
+ s2j2,
1849
+ cross=cross,
1850
+ backend=self.backend,
1851
+ )
1852
+ return sc_ret, vsc_ret
1308
1853
  else:
1309
1854
  return sc_ret
1310
1855
 
1311
- def square(self,x):
1856
+ def square(self, x):
1312
1857
  # the abs make the complex value usable for reduce_sum or mean
1313
- return scat(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1314
- self.backend.bk_square(self.backend.bk_abs(x.S0)),
1315
- self.backend.bk_square(self.backend.bk_abs(x.S1)),
1316
- self.backend.bk_square(self.backend.bk_abs(x.S2)),
1317
- self.backend.bk_square(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1318
-
1319
- def sqrt(self,x):
1858
+ return scat(
1859
+ self.backend.bk_square(self.backend.bk_abs(x.P00)),
1860
+ self.backend.bk_square(self.backend.bk_abs(x.S0)),
1861
+ self.backend.bk_square(self.backend.bk_abs(x.S1)),
1862
+ self.backend.bk_square(self.backend.bk_abs(x.S2)),
1863
+ self.backend.bk_square(self.backend.bk_abs(x.S2L)),
1864
+ x.j1,
1865
+ x.j2,
1866
+ backend=self.backend,
1867
+ )
1868
+
1869
+ def sqrt(self, x):
1320
1870
  # the abs make the complex value usable for reduce_sum or mean
1321
- return scat(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1322
- self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1323
- self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1324
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1325
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1871
+ return scat(
1872
+ self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1873
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1874
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1875
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1876
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),
1877
+ x.j1,
1878
+ x.j2,
1879
+ backend=self.backend,
1880
+ )
1881
+
1882
+ def reduce_distance(self, x, y, sigma=None):
1883
+
1884
+ if isinstance(x, scat):
1885
+ if sigma is None:
1886
+ result = self.diff_data(y.S0, x.S0, is_complex=False)
1887
+ result += self.diff_data(y.S1, x.S1)
1888
+ result += self.diff_data(y.P00, x.P00)
1889
+ result += self.diff_data(y.S2, x.S2)
1890
+ result += self.diff_data(y.S2L, x.S2L)
1891
+ else:
1892
+ result = self.diff_data(y.S0, x.S0, is_complex=False, sigma=sigma.S0)
1893
+ result += self.diff_data(y.S1, x.S1, sigma=sigma.S1)
1894
+ result += self.diff_data(y.P00, x.P00, sigma=sigma.P00)
1895
+ result += self.diff_data(y.S2, x.S2, sigma=sigma.S2)
1896
+ result += self.diff_data(y.S2L, x.S2L, sigma=sigma.S2L)
1897
+
1898
+ nval = (
1899
+ self.backend.bk_size(x.S0)
1900
+ + self.backend.bk_size(x.P00)
1901
+ + self.backend.bk_size(x.S1)
1902
+ + self.backend.bk_size(x.S2)
1903
+ + self.backend.bk_size(x.S2L)
1904
+ )
1905
+
1906
+ result /= self.backend.bk_cast(nval)
1907
+ else:
1908
+ return self.backend.bk_reduce_sum(x)
1909
+ return result
1326
1910
 
1327
- def reduce_mean(self,x,axis=None):
1911
+ def reduce_mean(self, x, axis=None):
1328
1912
  if axis is None:
1329
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))+ \
1330
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))+ \
1331
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))+ \
1332
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))+ \
1333
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1334
-
1335
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1336
- np.array(list(x.S0.shape)).prod()+ \
1337
- np.array(list(x.S1.shape)).prod()+ \
1338
- np.array(list(x.S2.shape)).prod()
1339
-
1340
- return tmp/ntmp
1913
+ tmp = (
1914
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))
1915
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))
1916
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))
1917
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))
1918
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1919
+ )
1920
+
1921
+ ntmp = (
1922
+ np.array(list(x.P00.shape)).prod()
1923
+ + np.array(list(x.S0.shape)).prod()
1924
+ + np.array(list(x.S1.shape)).prod()
1925
+ + np.array(list(x.S2.shape)).prod()
1926
+ )
1927
+
1928
+ return tmp / ntmp
1341
1929
  else:
1342
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00,axis=axis))+ \
1343
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0,axis=axis))+ \
1344
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1,axis=axis))+ \
1345
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2,axis=axis))+ \
1346
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L,axis=axis))
1347
-
1348
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1349
- np.array(list(x.S0.shape)).prod()+ \
1350
- np.array(list(x.S1.shape)).prod()+ \
1351
- np.array(list(x.S2.shape)).prod()+ \
1352
- np.array(list(x.S2L.shape)).prod()
1353
-
1354
- return tmp/ntmp
1355
-
1356
- def reduce_sum(self,x,axis=None):
1930
+ tmp = (
1931
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00, axis=axis))
1932
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0, axis=axis))
1933
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1, axis=axis))
1934
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2, axis=axis))
1935
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L, axis=axis))
1936
+ )
1937
+
1938
+ ntmp = (
1939
+ np.array(list(x.P00.shape)).prod()
1940
+ + np.array(list(x.S0.shape)).prod()
1941
+ + np.array(list(x.S1.shape)).prod()
1942
+ + np.array(list(x.S2.shape)).prod()
1943
+ + np.array(list(x.S2L.shape)).prod()
1944
+ )
1945
+
1946
+ return tmp / ntmp
1947
+
1948
+ def reduce_sum(self, x, axis=None):
1357
1949
  if axis is None:
1358
- return self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))+ \
1359
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))+ \
1360
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))+ \
1361
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))+ \
1362
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1950
+ return (
1951
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))
1952
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
1953
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
1954
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
1955
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1956
+ )
1363
1957
  else:
1364
- return scat(self.backend.bk_reduce_sum(x.P00,axis=axis),
1365
- self.backend.bk_reduce_sum(x.S0,axis=axis),
1366
- self.backend.bk_reduce_sum(x.S1,axis=axis),
1367
- self.backend.bk_reduce_sum(x.S2,axis=axis),
1368
- self.backend.bk_reduce_sum(x.S2L,axis=axis),x.j1,x.j2,backend=self.backend)
1369
-
1370
- def ldiff(self,sig,x):
1371
- return scat(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1372
- x.domult(sig.S0,x.S0)*x.domult(sig.S0,x.S0),
1373
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1374
- x.domult(sig.S2,x.S2)*x.domult(sig.S2,x.S2),
1375
- x.domult(sig.S2L,x.S2L)*x.domult(sig.S2L,x.S2L),x.j1,x.j2,backend=self.backend)
1376
-
1377
- def log(self,x):
1378
- return scat(self.backend.bk_log(x.P00),
1379
- self.backend.bk_log(x.S0),
1380
- self.backend.bk_log(x.S1),
1381
- self.backend.bk_log(x.S2),
1382
- self.backend.bk_log(x.S2L),x.j1,x.j2,backend=self.backend)
1383
- def abs(self,x):
1384
- return scat(self.backend.bk_abs(x.P00),
1385
- self.backend.bk_abs(x.S0),
1386
- self.backend.bk_abs(x.S1),
1387
- self.backend.bk_abs(x.S2),
1388
- self.backend.bk_abs(x.S2L),x.j1,x.j2,backend=self.backend)
1389
- def inv(self,x):
1390
- return scat(1/(x.P00),1/(x.S0),1/(x.S1),1/(x.S2),1/(x.S2L),x.j1,x.j2,backend=self.backend)
1958
+ return scat(
1959
+ self.backend.bk_reduce_sum(x.P00, axis=axis),
1960
+ self.backend.bk_reduce_sum(x.S0, axis=axis),
1961
+ self.backend.bk_reduce_sum(x.S1, axis=axis),
1962
+ self.backend.bk_reduce_sum(x.S2, axis=axis),
1963
+ self.backend.bk_reduce_sum(x.S2L, axis=axis),
1964
+ x.j1,
1965
+ x.j2,
1966
+ backend=self.backend,
1967
+ )
1968
+
1969
+ def ldiff(self, sig, x):
1970
+ return scat(
1971
+ x.domult(sig.P00, x.P00) * x.domult(sig.P00, x.P00),
1972
+ x.domult(sig.S0, x.S0) * x.domult(sig.S0, x.S0),
1973
+ x.domult(sig.S1, x.S1) * x.domult(sig.S1, x.S1),
1974
+ x.domult(sig.S2, x.S2) * x.domult(sig.S2, x.S2),
1975
+ x.domult(sig.S2L, x.S2L) * x.domult(sig.S2L, x.S2L),
1976
+ x.j1,
1977
+ x.j2,
1978
+ backend=self.backend,
1979
+ )
1980
+
1981
+ def log(self, x):
1982
+ return scat(
1983
+ self.backend.bk_log(x.P00),
1984
+ self.backend.bk_log(x.S0),
1985
+ self.backend.bk_log(x.S1),
1986
+ self.backend.bk_log(x.S2),
1987
+ self.backend.bk_log(x.S2L),
1988
+ x.j1,
1989
+ x.j2,
1990
+ backend=self.backend,
1991
+ )
1992
+
1993
+ def abs(self, x):
1994
+ return scat(
1995
+ self.backend.bk_abs(x.P00),
1996
+ self.backend.bk_abs(x.S0),
1997
+ self.backend.bk_abs(x.S1),
1998
+ self.backend.bk_abs(x.S2),
1999
+ self.backend.bk_abs(x.S2L),
2000
+ x.j1,
2001
+ x.j2,
2002
+ backend=self.backend,
2003
+ )
2004
+
2005
+ def inv(self, x):
2006
+ return scat(
2007
+ 1 / (x.P00),
2008
+ 1 / (x.S0),
2009
+ 1 / (x.S1),
2010
+ 1 / (x.S2),
2011
+ 1 / (x.S2L),
2012
+ x.j1,
2013
+ x.j2,
2014
+ backend=self.backend,
2015
+ )
1391
2016
 
1392
2017
  def one(self):
1393
- return scat(1.0,1.0,1.0,1.0,1.0,[0],[0],backend=self.backend)
2018
+ return scat(1.0, 1.0, 1.0, 1.0, 1.0, [0], [0], backend=self.backend)
1394
2019
 
1395
- @tf.function
1396
- def eval_comp_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
2020
+ @tf_function
2021
+ def eval_comp_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1397
2022
 
1398
- res=self.eval(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1399
- return res.P00,res.S0,res.S1,res.S2,res.S2L,res.j1,res.j2
2023
+ res = self.eval(image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off)
2024
+ return res.P00, res.S0, res.S1, res.S2, res.S2L, res.j1, res.j2
1400
2025
 
1401
- def eval_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1402
- p0,s0,s1,s2,s2l,j1,j2=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1403
- return scat(p0,s0,s1,s2,s2l,j1,j2,backend=self.backend)
1404
-
1405
-
1406
-
2026
+ def eval_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
2027
+ p0, s0, s1, s2, s2l, j1, j2 = self.eval_comp_fast(
2028
+ image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off
2029
+ )
2030
+ return scat(p0, s0, s1, s2, s2l, j1, j2, backend=self.backend)