foscat 3.0.8__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat.py CHANGED
@@ -1,409 +1,602 @@
1
- import foscat.FoCUS as FOC
2
- import numpy as np
3
- import tensorflow as tf
4
1
  import pickle
5
- import foscat.backend as bk
2
+ import sys
3
+
6
4
  import healpy as hp
7
-
5
+ import numpy as np
6
+
7
+ import foscat.backend as bk
8
+ import foscat.FoCUS as FOC
9
+
10
+ # Vérifier si TensorFlow est importé et défini
11
+ tf_defined = "tensorflow" in sys.modules
12
+
13
+ if tf_defined:
14
+ import tensorflow as tf
15
+
16
+ tf_function = (
17
+ tf.function
18
+ ) # Facultatif : si vous voulez utiliser TensorFlow dans ce script
19
+ else:
20
+
21
+ def tf_function(func):
22
+ return func
23
+
24
+
8
25
  def read(filename):
9
- thescat=scat(1,1,1,1,1,[0],[0])
26
+ thescat = scat(1, 1, 1, 1, 1, [0], [0])
10
27
  return thescat.read(filename)
11
-
28
+
29
+
12
30
  class scat:
13
- def __init__(self,p00,s0,s1,s2,s2l,j1,j2,cross=False,backend=None):
14
- self.bk_type='SCAT'
15
- self.P00=p00
16
- self.S0=s0
17
- self.S1=s1
18
- self.S2=s2
19
- self.S2L=s2l
20
- self.j1=j1
21
- self.j2=j2
22
- self.cross=cross
23
- self.backend=backend
24
-
25
- def set_bk_type(self,bk_type):
26
- self.bk_type=bk_type
27
-
31
+ def __init__(self, p00, s0, s1, s2, s2l, j1, j2, cross=False, backend=None):
32
+ self.bk_type = "SCAT"
33
+ self.P00 = p00
34
+ self.S0 = s0
35
+ self.S1 = s1
36
+ self.S2 = s2
37
+ self.S2L = s2l
38
+ self.j1 = j1
39
+ self.j2 = j2
40
+ self.cross = cross
41
+ self.backend = backend
42
+
43
+ def set_bk_type(self, bk_type):
44
+ self.bk_type = bk_type
45
+
28
46
  def get_j_idx(self):
29
- return self.j1,self.j2
30
-
47
+ return self.j1, self.j2
48
+
31
49
  def get_S0(self):
32
- return(self.S0)
50
+ return self.S0
33
51
 
34
52
  def get_S1(self):
35
- return(self.S1)
36
-
53
+ return self.S1
54
+
37
55
  def get_S2(self):
38
- return(self.S2)
56
+ return self.S2
39
57
 
40
58
  def get_S2L(self):
41
- return(self.S2L)
59
+ return self.S2L
42
60
 
43
61
  def get_P00(self):
44
- return(self.P00)
62
+ return self.P00
45
63
 
46
64
  def reset_P00(self):
47
- self.P00=0*self.P00
65
+ self.P00 = 0 * self.P00
48
66
 
49
67
  def constant(self):
50
- return scat(self.backend.constant(self.P00 ), \
51
- self.backend.constant(self.S0 ), \
52
- self.backend.constant(self.S1 ), \
53
- self.backend.constant(self.S2 ), \
54
- self.backend.constant(self.S2L ), \
55
- self.j1 , \
56
- self.j2 ,backend=self.backend)
57
-
58
- def domult(self,x,y):
59
- if x.dtype==y.dtype:
60
- return x*y
61
- if x.dtype=='complex64' or x.dtype=='complex128':
62
-
63
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
64
- else:
65
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
66
-
67
- def dodiv(self,x,y):
68
- if x.dtype==y.dtype:
69
- return x/y
70
- if x.dtype=='complex64' or x.dtype=='complex128':
71
-
72
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
73
- else:
74
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
75
-
76
- def domin(self,x,y):
77
- if x.dtype==y.dtype:
78
- return x-y
79
- if x.dtype=='complex64' or x.dtype=='complex128':
80
-
81
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
82
- else:
83
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
84
-
85
- def doadd(self,x,y):
86
- if x.dtype==y.dtype:
87
- return x+y
88
- if x.dtype=='complex64' or x.dtype=='complex128':
89
-
90
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
91
- else:
92
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
93
-
94
- def relu(self):
95
-
96
- return scat(self.backend.bk_relu(self.P00), \
97
- self.backend.bk_relu(self.S0), \
98
- self.backend.bk_relu(self.S1), \
99
- self.backend.bk_relu(self.S2), \
100
- self.backend.bk_relu(self.S2L), \
101
- self.j1,self.j2,backend=self.backend)
102
-
103
- def __add__(self,other):
104
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
105
- isinstance(other, bool) or isinstance(other, scat)
106
-
68
+ return scat(
69
+ self.backend.constant(self.P00),
70
+ self.backend.constant(self.S0),
71
+ self.backend.constant(self.S1),
72
+ self.backend.constant(self.S2),
73
+ self.backend.constant(self.S2L),
74
+ self.j1,
75
+ self.j2,
76
+ backend=self.backend,
77
+ )
78
+
79
+ def domult(self, x, y):
80
+ try:
81
+ return x * y
82
+ except:
83
+ if x.dtype == y.dtype:
84
+ return x * y
85
+ if self.backend.bk_is_complex(x):
86
+
87
+ return self.backend.bk_complex(
88
+ self.backend.bk_real(x) * y, self.backend.bk_imag(x) * y
89
+ )
90
+ else:
91
+ return self.backend.bk_complex(
92
+ self.backend.bk_real(y) * x, self.backend.bk_imag(y) * x
93
+ )
94
+
95
+ def dodiv(self, x, y):
96
+ try:
97
+ return x / y
98
+ except:
99
+ if x.dtype == y.dtype:
100
+ return x / y
101
+ if self.backend.bk_is_complex(x):
102
+
103
+ return self.backend.bk_complex(
104
+ self.backend.bk_real(x) / y, self.backend.bk_imag(x) / y
105
+ )
106
+ else:
107
+ return self.backend.bk_complex(
108
+ x / self.backend.bk_real(y), x / self.backend.bk_imag(y)
109
+ )
110
+
111
+ def domin(self, x, y):
112
+ try:
113
+ return x - y
114
+ except:
115
+ if x.dtype == y.dtype:
116
+ return x - y
117
+
118
+ if self.backend.bk_is_complex(x):
119
+
120
+ return self.backend.bk_complex(
121
+ self.backend.bk_real(x) - y, self.backend.bk_imag(x) - y
122
+ )
123
+ else:
124
+ return self.backend.bk_complex(
125
+ x - self.backend.bk_real(y), x - self.backend.bk_imag(y)
126
+ )
127
+
128
+ def doadd(self, x, y):
129
+ try:
130
+ return x + y
131
+ except:
132
+ if x.dtype == y.dtype:
133
+ return x + y
134
+ if self.backend.bk_is_complex(x):
135
+
136
+ return self.backend.bk_complex(
137
+ self.backend.bk_real(x) + y, self.backend.bk_imag(x) + y
138
+ )
139
+ else:
140
+ return self.backend.bk_complex(
141
+ x + self.backend.bk_real(y), x + self.backend.bk_imag(y)
142
+ )
143
+
144
+ def __add__(self, other):
145
+ assert (
146
+ isinstance(other, float)
147
+ or isinstance(other, np.float32)
148
+ or isinstance(other, int)
149
+ or isinstance(other, bool)
150
+ or isinstance(other, scat)
151
+ )
152
+
107
153
  if isinstance(other, scat):
108
- return scat(self.doadd(self.P00,other.P00), \
109
- self.doadd(self.S0, other.S0), \
110
- self.doadd(self.S1, other.S1), \
111
- self.doadd(self.S2, other.S2), \
112
- self.doadd(self.S2L, other.S2L), \
113
- self.j1,self.j2,backend=self.backend)
154
+ return scat(
155
+ self.doadd(self.P00, other.P00),
156
+ self.doadd(self.S0, other.S0),
157
+ self.doadd(self.S1, other.S1),
158
+ self.doadd(self.S2, other.S2),
159
+ self.doadd(self.S2L, other.S2L),
160
+ self.j1,
161
+ self.j2,
162
+ backend=self.backend,
163
+ )
114
164
  else:
115
- return scat((self.P00+ other), \
116
- (self.S0+ other), \
117
- (self.S1+ other), \
118
- (self.S2+ other), \
119
- (self.S2L+ other), \
120
- self.j1,self.j2,backend=self.backend)
121
-
122
- def toreal(self,value):
165
+ return scat(
166
+ (self.P00 + other),
167
+ (self.S0 + other),
168
+ (self.S1 + other),
169
+ (self.S2 + other),
170
+ (self.S2L + other),
171
+ self.j1,
172
+ self.j2,
173
+ backend=self.backend,
174
+ )
175
+
176
+ def toreal(self, value):
123
177
  if value is None:
124
178
  return None
125
-
179
+
126
180
  return self.backend.bk_real(value)
127
181
 
128
- def addcomplex(self,value,amp):
182
+ def addcomplex(self, value, amp):
129
183
  if value is None:
130
184
  return None
131
-
132
- return self.backend.bk_complex(value,amp*value)
133
-
134
- def add_complex(self,amp):
135
- return scat(self.addcomplex(self.P00,amp), \
136
- self.addcomplex(self.S0,amp), \
137
- self.addcomplex(self.S1,amp), \
138
- self.addcomplex(self.S2,amp), \
139
- self.addcomplex(self.S2L,amp), \
140
- self.j1,self.j2,backend=self.backend)
141
-
185
+
186
+ return self.backend.bk_complex(value, amp * value)
187
+
188
+ def add_complex(self, amp):
189
+ return scat(
190
+ self.addcomplex(self.P00, amp),
191
+ self.addcomplex(self.S0, amp),
192
+ self.addcomplex(self.S1, amp),
193
+ self.addcomplex(self.S2, amp),
194
+ self.addcomplex(self.S2L, amp),
195
+ self.j1,
196
+ self.j2,
197
+ backend=self.backend,
198
+ )
199
+
142
200
  def real(self):
143
- return scat(self.toreal(self.P00), \
144
- self.toreal(self.S0), \
145
- self.toreal(self.S1), \
146
- self.toreal(self.S2), \
147
- self.toreal(self.S2L), \
148
- self.j1,self.j2,backend=self.backend)
149
-
150
- def __radd__(self,other):
201
+ return scat(
202
+ self.toreal(self.P00),
203
+ self.toreal(self.S0),
204
+ self.toreal(self.S1),
205
+ self.toreal(self.S2),
206
+ self.toreal(self.S2L),
207
+ self.j1,
208
+ self.j2,
209
+ backend=self.backend,
210
+ )
211
+
212
+ def __radd__(self, other):
151
213
  return self.__add__(other)
152
214
 
153
- def __truediv__(self,other):
154
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
155
- isinstance(other, bool) or isinstance(other, scat)
156
-
215
+ def __truediv__(self, other):
216
+ assert (
217
+ isinstance(other, float)
218
+ or isinstance(other, np.float32)
219
+ or isinstance(other, int)
220
+ or isinstance(other, bool)
221
+ or isinstance(other, scat)
222
+ )
223
+
157
224
  if isinstance(other, scat):
158
- return scat(self.dodiv(self.P00, other.P00), \
159
- self.dodiv(self.S0, other.S0), \
160
- self.dodiv(self.S1, other.S1), \
161
- self.dodiv(self.S2, other.S2), \
162
- self.dodiv(self.S2L, other.S2L), \
163
- self.j1,self.j2,backend=self.backend)
225
+ return scat(
226
+ self.dodiv(self.P00, other.P00),
227
+ self.dodiv(self.S0, other.S0),
228
+ self.dodiv(self.S1, other.S1),
229
+ self.dodiv(self.S2, other.S2),
230
+ self.dodiv(self.S2L, other.S2L),
231
+ self.j1,
232
+ self.j2,
233
+ backend=self.backend,
234
+ )
164
235
  else:
165
- return scat((self.P00/ other), \
166
- (self.S0/ other), \
167
- (self.S1/ other), \
168
- (self.S2/ other), \
169
- (self.S2L/ other), \
170
- self.j1,self.j2,backend=self.backend)
171
-
172
-
173
- def __rtruediv__(self,other):
174
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
175
- isinstance(other, bool) or isinstance(other, scat)
176
-
236
+ return scat(
237
+ (self.P00 / other),
238
+ (self.S0 / other),
239
+ (self.S1 / other),
240
+ (self.S2 / other),
241
+ (self.S2L / other),
242
+ self.j1,
243
+ self.j2,
244
+ backend=self.backend,
245
+ )
246
+
247
+ def __rtruediv__(self, other):
248
+ assert (
249
+ isinstance(other, float)
250
+ or isinstance(other, np.float32)
251
+ or isinstance(other, int)
252
+ or isinstance(other, bool)
253
+ or isinstance(other, scat)
254
+ )
255
+
177
256
  if isinstance(other, scat):
178
- return scat(self.dodiv(other.P00, self.P00), \
179
- self.dodiv(other.S0 , self.S0), \
180
- self.dodiv(other.S1 , self.S1), \
181
- self.dodiv(other.S2 , self.S2), \
182
- self.dodiv(other.S2L , self.S2L), \
183
- self.j1,self.j2,backend=self.backend)
257
+ return scat(
258
+ self.dodiv(other.P00, self.P00),
259
+ self.dodiv(other.S0, self.S0),
260
+ self.dodiv(other.S1, self.S1),
261
+ self.dodiv(other.S2, self.S2),
262
+ self.dodiv(other.S2L, self.S2L),
263
+ self.j1,
264
+ self.j2,
265
+ backend=self.backend,
266
+ )
184
267
  else:
185
- return scat((other/ self.P00), \
186
- (other / self.S0), \
187
- (other / self.S1), \
188
- (other / self.S2), \
189
- (other / self.S2L), \
190
- self.j1,self.j2,backend=self.backend)
191
-
192
- def __sub__(self,other):
193
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
194
- isinstance(other, bool) or isinstance(other, scat)
195
-
268
+ return scat(
269
+ (other / self.P00),
270
+ (other / self.S0),
271
+ (other / self.S1),
272
+ (other / self.S2),
273
+ (other / self.S2L),
274
+ self.j1,
275
+ self.j2,
276
+ backend=self.backend,
277
+ )
278
+
279
+ def __sub__(self, other):
280
+ assert (
281
+ isinstance(other, float)
282
+ or isinstance(other, np.float32)
283
+ or isinstance(other, int)
284
+ or isinstance(other, bool)
285
+ or isinstance(other, scat)
286
+ )
287
+
196
288
  if isinstance(other, scat):
197
- return scat(self.domin(self.P00, other.P00), \
198
- self.domin(self.S0, other.S0), \
199
- self.domin(self.S1, other.S1), \
200
- self.domin(self.S2, other.S2), \
201
- self.domin(self.S2L, other.S2L), \
202
- self.j1,self.j2,backend=self.backend)
289
+ return scat(
290
+ self.domin(self.P00, other.P00),
291
+ self.domin(self.S0, other.S0),
292
+ self.domin(self.S1, other.S1),
293
+ self.domin(self.S2, other.S2),
294
+ self.domin(self.S2L, other.S2L),
295
+ self.j1,
296
+ self.j2,
297
+ backend=self.backend,
298
+ )
203
299
  else:
204
- return scat((self.P00- other), \
205
- (self.S0- other), \
206
- (self.S1- other), \
207
- (self.S2- other), \
208
- (self.S2L- other), \
209
- self.j1,self.j2,backend=self.backend)
210
-
211
- def __rsub__(self,other):
212
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
213
- isinstance(other, bool) or isinstance(other, scat)
214
-
300
+ return scat(
301
+ (self.P00 - other),
302
+ (self.S0 - other),
303
+ (self.S1 - other),
304
+ (self.S2 - other),
305
+ (self.S2L - other),
306
+ self.j1,
307
+ self.j2,
308
+ backend=self.backend,
309
+ )
310
+
311
+ def __rsub__(self, other):
312
+ assert (
313
+ isinstance(other, float)
314
+ or isinstance(other, np.float32)
315
+ or isinstance(other, int)
316
+ or isinstance(other, bool)
317
+ or isinstance(other, scat)
318
+ )
319
+
215
320
  if isinstance(other, scat):
216
- return scat(self.domin(other.P00,self.P00), \
217
- self.domin(other.S0, self.S0), \
218
- self.domin(other.S1, self.S1), \
219
- self.domin(other.S2, self.S2), \
220
- self.domin(other.S2L, self.S2L), \
221
- self.j1,self.j2,backend=self.backend)
321
+ return scat(
322
+ self.domin(other.P00, self.P00),
323
+ self.domin(other.S0, self.S0),
324
+ self.domin(other.S1, self.S1),
325
+ self.domin(other.S2, self.S2),
326
+ self.domin(other.S2L, self.S2L),
327
+ self.j1,
328
+ self.j2,
329
+ backend=self.backend,
330
+ )
222
331
  else:
223
- return scat((other-self.P00), \
224
- (other-self.S0), \
225
- (other-self.S1), \
226
- (other-self.S2), \
227
- (other-self.S2L), \
228
- self.j1,self.j2,backend=self.backend)
229
-
230
- def __mul__(self,other):
231
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
232
- isinstance(other, bool) or isinstance(other, scat)
233
-
332
+ return scat(
333
+ (other - self.P00),
334
+ (other - self.S0),
335
+ (other - self.S1),
336
+ (other - self.S2),
337
+ (other - self.S2L),
338
+ self.j1,
339
+ self.j2,
340
+ backend=self.backend,
341
+ )
342
+
343
+ def __mul__(self, other):
344
+ assert (
345
+ isinstance(other, float)
346
+ or isinstance(other, np.float32)
347
+ or isinstance(other, int)
348
+ or isinstance(other, bool)
349
+ or isinstance(other, scat)
350
+ )
351
+
234
352
  if isinstance(other, scat):
235
- return scat(self.domult(self.P00, other.P00), \
236
- self.domult(self.S0, other.S0), \
237
- self.domult(self.S1, other.S1), \
238
- self.domult(self.S2, other.S2), \
239
- self.domult(self.S2L, other.S2L), \
240
- self.j1,self.j2,backend=self.backend)
353
+ return scat(
354
+ self.domult(self.P00, other.P00),
355
+ self.domult(self.S0, other.S0),
356
+ self.domult(self.S1, other.S1),
357
+ self.domult(self.S2, other.S2),
358
+ self.domult(self.S2L, other.S2L),
359
+ self.j1,
360
+ self.j2,
361
+ backend=self.backend,
362
+ )
241
363
  else:
242
- return scat((self.P00* other), \
243
- (self.S0* other), \
244
- (self.S1* other), \
245
- (self.S2* other), \
246
- (self.S2L* other), \
247
- self.j1,self.j2,backend=self.backend)
364
+ return scat(
365
+ (self.P00 * other),
366
+ (self.S0 * other),
367
+ (self.S1 * other),
368
+ (self.S2 * other),
369
+ (self.S2L * other),
370
+ self.j1,
371
+ self.j2,
372
+ backend=self.backend,
373
+ )
374
+
248
375
  def relu(self):
249
- return scat(self.backend.bk_relu(self.P00),
250
- self.backend.bk_relu(self.S0),
251
- self.backend.bk_relu(self.S1),
252
- self.backend.bk_relu(self.S2),
253
- self.backend.bk_relu(self.S2L), \
254
- self.j1,self.j2,backend=self.backend)
255
-
256
-
257
- def __rmul__(self,other):
258
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
259
- isinstance(other, bool) or isinstance(other, scat)
260
-
376
+ return scat(
377
+ self.backend.bk_relu(self.P00),
378
+ self.backend.bk_relu(self.S0),
379
+ self.backend.bk_relu(self.S1),
380
+ self.backend.bk_relu(self.S2),
381
+ self.backend.bk_relu(self.S2L),
382
+ self.j1,
383
+ self.j2,
384
+ backend=self.backend,
385
+ )
386
+
387
+ def __rmul__(self, other):
388
+ assert (
389
+ isinstance(other, float)
390
+ or isinstance(other, np.float32)
391
+ or isinstance(other, int)
392
+ or isinstance(other, bool)
393
+ or isinstance(other, scat)
394
+ )
395
+
261
396
  if isinstance(other, scat):
262
- return scat(self.domult(self.P00, other.P00), \
263
- self.domult(self.S0, other.S0), \
264
- self.domult(self.S1, other.S1), \
265
- self.domult(self.S2, other.S2), \
266
- self.domult(self.S2L, other.S2L), \
267
- self.j1,self.j2,backend=self.backend)
397
+ return scat(
398
+ self.domult(self.P00, other.P00),
399
+ self.domult(self.S0, other.S0),
400
+ self.domult(self.S1, other.S1),
401
+ self.domult(self.S2, other.S2),
402
+ self.domult(self.S2L, other.S2L),
403
+ self.j1,
404
+ self.j2,
405
+ backend=self.backend,
406
+ )
268
407
  else:
269
- return scat((self.P00* other), \
270
- (self.S0* other), \
271
- (self.S1* other), \
272
- (self.S2* other), \
273
- (self.S2L* other), \
274
- self.j1,self.j2,backend=self.backend)
275
-
276
- def l1_abs(self,x):
277
- y=self.get_np(x)
278
- if y.dtype=='complex64' or y.dtype=='complex128':
279
- tmp=y.real*y.real+y.imag*y.imag
280
- tmp=np.sign(tmp)*np.sqrt(np.fabs(tmp))
281
- y=tmp
282
-
283
- return(y)
284
-
285
- def plot(self,name=None,hold=True,color='blue',lw=1,legend=True):
408
+ return scat(
409
+ (self.P00 * other),
410
+ (self.S0 * other),
411
+ (self.S1 * other),
412
+ (self.S2 * other),
413
+ (self.S2L * other),
414
+ self.j1,
415
+ self.j2,
416
+ backend=self.backend,
417
+ )
418
+
419
+ def l1_abs(self, x):
420
+ y = self.get_np(x)
421
+ if self.backend.bk_is_complex(y):
422
+ tmp = y.real * y.real + y.imag * y.imag
423
+ tmp = np.sign(tmp) * np.sqrt(np.fabs(tmp))
424
+ y = tmp
425
+
426
+ return y
427
+
428
+ def plot(self, name=None, hold=True, color="blue", lw=1, legend=True):
286
429
 
287
430
  import matplotlib.pyplot as plt
288
431
 
289
- j1,j2=self.get_j_idx()
290
-
432
+ j1, j2 = self.get_j_idx()
433
+
291
434
  if name is None:
292
- name=''
435
+ name = ""
293
436
 
294
437
  if hold:
295
- plt.figure(figsize=(16,8))
296
-
297
- test=None
438
+ plt.figure(figsize=(16, 8))
439
+
440
+ test = None
298
441
  plt.subplot(2, 2, 1)
299
- tmp=abs(self.get_np(self.S1))
300
- if len(tmp.shape)==4:
442
+ tmp = abs(self.get_np(self.S1))
443
+ if len(tmp.shape) == 4:
301
444
  for k in range(tmp.shape[3]):
302
445
  for i1 in range(tmp.shape[0]):
303
446
  for i2 in range(tmp.shape[1]):
304
447
  if test is None:
305
- test=1
306
- plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
448
+ test = 1
449
+ plt.plot(
450
+ tmp[i1, i2, :, k],
451
+ color=color,
452
+ label=r"%s $S_{1}$" % (name),
453
+ lw=lw,
454
+ )
307
455
  else:
308
- plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
456
+ plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
309
457
  else:
310
458
  for k in range(tmp.shape[2]):
311
459
  for i1 in range(tmp.shape[0]):
312
460
  if test is None:
313
- test=1
314
- plt.plot(tmp[i1,:,k],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
461
+ test = 1
462
+ plt.plot(
463
+ tmp[i1, :, k],
464
+ color=color,
465
+ label=r"%s $S_{1}$" % (name),
466
+ lw=lw,
467
+ )
315
468
  else:
316
- plt.plot(tmp[i1,:,k],color=color, lw=lw)
317
- plt.yscale('log')
318
- plt.ylabel('S1')
319
- plt.xlabel(r'$j_{1}$')
469
+ plt.plot(tmp[i1, :, k], color=color, lw=lw)
470
+ plt.yscale("log")
471
+ plt.ylabel("S1")
472
+ plt.xlabel(r"$j_{1}$")
320
473
  plt.legend()
321
474
 
322
- test=None
475
+ test = None
323
476
  plt.subplot(2, 2, 2)
324
- tmp=abs(self.get_np(self.P00))
325
- if len(tmp.shape)==4:
477
+ tmp = abs(self.get_np(self.P00))
478
+ if len(tmp.shape) == 4:
326
479
  for k in range(tmp.shape[3]):
327
480
  for i1 in range(tmp.shape[0]):
328
- for i2 in range(tmp.shape[0]):
481
+ for i2 in range(tmp.shape[1]):
329
482
  if test is None:
330
- test=1
331
- plt.plot(tmp[i1,i2,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
483
+ test = 1
484
+ plt.plot(
485
+ tmp[i1, i2, :, k],
486
+ color=color,
487
+ label=r"%s $P_{00}$" % (name),
488
+ lw=lw,
489
+ )
332
490
  else:
333
- plt.plot(tmp[i1,i2,:,k],color=color, lw=lw)
491
+ plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
334
492
  else:
335
493
  for k in range(tmp.shape[2]):
336
494
  for i1 in range(tmp.shape[0]):
337
495
  if test is None:
338
- test=1
339
- plt.plot(tmp[i1,:,k],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
496
+ test = 1
497
+ plt.plot(
498
+ tmp[i1, :, k],
499
+ color=color,
500
+ label=r"%s $P_{00}$" % (name),
501
+ lw=lw,
502
+ )
340
503
  else:
341
- plt.plot(tmp[i1,:,k],color=color, lw=lw)
342
- plt.yscale('log')
343
- plt.ylabel('P00')
344
- plt.xlabel(r'$j_{1}$')
504
+ plt.plot(tmp[i1, :, k], color=color, lw=lw)
505
+ plt.yscale("log")
506
+ plt.ylabel("P00")
507
+ plt.xlabel(r"$j_{1}$")
345
508
  plt.legend()
346
-
347
- ax1=plt.subplot(2, 2, 3)
509
+
510
+ ax1 = plt.subplot(2, 2, 3)
348
511
  ax2 = ax1.twiny()
349
- n=0
350
- tmp=abs(self.get_np(self.S2))
351
- lname=r'%s $S_{2}$' % (name)
352
- ax1.set_ylabel(r'$S_{2}$ [L1 norm]')
353
- test=None
354
- tabx=[]
355
- tabnx=[]
356
- tab2x=[]
357
- tab2nx=[]
358
- if len(tmp.shape)==5:
512
+ n = 0
513
+ tmp = abs(self.get_np(self.S2))
514
+ lname = r"%s $S_{2}$" % (name)
515
+ ax1.set_ylabel(r"$S_{2}$ [L1 norm]")
516
+ test = None
517
+ tabx = []
518
+ tabnx = []
519
+ tab2x = []
520
+ tab2nx = []
521
+ if len(tmp.shape) == 5:
359
522
  for i0 in range(tmp.shape[0]):
360
523
  for i1 in range(tmp.shape[1]):
361
- for i2 in range(j1.max()+1):
524
+ for i2 in range(j1.max() + 1):
362
525
  for i3 in range(tmp.shape[3]):
363
526
  for i4 in range(tmp.shape[4]):
364
- if j2[j1==i2].shape[0]==1:
365
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4],'.', \
366
- color=color, lw=lw)
527
+ if j2[j1 == i2].shape[0] == 1:
528
+ ax1.plot(
529
+ j2[j1 == i2] + n,
530
+ tmp[i0, i1, j1 == i2, i3, i4],
531
+ ".",
532
+ color=color,
533
+ lw=lw,
534
+ )
367
535
  else:
368
536
  if legend and test is None:
369
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
370
- color=color, label=lname, lw=lw)
371
- test=1
372
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
373
- color=color, lw=lw)
374
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
375
- tabx=tabx+[k+n for k in j2[j1==i2]]
376
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
377
- tab2nx=tab2nx+['%d'%(i2)]
378
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
379
- n=n+j2[j1==i2].shape[0]-1
537
+ ax1.plot(
538
+ j2[j1 == i2] + n,
539
+ tmp[i0, i1, j1 == i2, i3, i4],
540
+ color=color,
541
+ label=lname,
542
+ lw=lw,
543
+ )
544
+ test = 1
545
+ ax1.plot(
546
+ j2[j1 == i2] + n,
547
+ tmp[i0, i1, j1 == i2, i3, i4],
548
+ color=color,
549
+ lw=lw,
550
+ )
551
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
552
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
553
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
554
+ tab2nx = tab2nx + ["%d" % (i2)]
555
+ ax1.axvline(
556
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
557
+ )
558
+ n = n + j2[j1 == i2].shape[0] - 1
380
559
  else:
381
560
  for i0 in range(tmp.shape[0]):
382
- for i2 in range(j1.max()+1):
561
+ for i2 in range(j1.max() + 1):
383
562
  for i3 in range(tmp.shape[2]):
384
563
  for i4 in range(tmp.shape[3]):
385
- if j2[j1==i2].shape[0]==1:
386
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4],'.', \
387
- color=color, lw=lw)
564
+ if j2[j1 == i2].shape[0] == 1:
565
+ ax1.plot(
566
+ j2[j1 == i2] + n,
567
+ tmp[i0, j1 == i2, i3, i4],
568
+ ".",
569
+ color=color,
570
+ lw=lw,
571
+ )
388
572
  else:
389
573
  if legend and test is None:
390
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
391
- color=color, label=lname, lw=lw)
392
- test=1
393
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
394
- color=color, lw=lw)
395
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
396
- tabx=tabx+[k+n for k in j2[j1==i2]]
397
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
398
- tab2nx=tab2nx+['%d'%(i2)]
399
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
400
- n=n+j2[j1==i2].shape[0]-1
401
- plt.yscale('log')
402
- ax1.set_xlim(0,n+2)
574
+ ax1.plot(
575
+ j2[j1 == i2] + n,
576
+ tmp[i0, j1 == i2, i3, i4],
577
+ color=color,
578
+ label=lname,
579
+ lw=lw,
580
+ )
581
+ test = 1
582
+ ax1.plot(
583
+ j2[j1 == i2] + n,
584
+ tmp[i0, j1 == i2, i3, i4],
585
+ color=color,
586
+ lw=lw,
587
+ )
588
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
589
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
590
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
591
+ tab2nx = tab2nx + ["%d" % (i2)]
592
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
593
+ n = n + j2[j1 == i2].shape[0] - 1
594
+ plt.yscale("log")
595
+ ax1.set_xlim(0, n + 2)
403
596
  ax1.set_xticks(tabx)
404
- ax1.set_xticklabels(tabnx,fontsize=6)
405
- ax1.set_xlabel(r"$j_{2}$ ",fontsize=6)
406
-
597
+ ax1.set_xticklabels(tabnx, fontsize=6)
598
+ ax1.set_xlabel(r"$j_{2}$ ", fontsize=6)
599
+
407
600
  # Move twinned axis ticks and label from top to bottom
408
601
  ax2.xaxis.set_ticks_position("bottom")
409
602
  ax2.xaxis.set_label_position("bottom")
@@ -411,7 +604,7 @@ class scat:
411
604
  # Offset the twin axis below the host
412
605
  ax2.spines["bottom"].set_position(("axes", -0.15))
413
606
 
414
- # Turn on the frame for the twin axis, but then hide all
607
+ # Turn on the frame for the twin axis, but then hide all
415
608
  # but the bottom spine
416
609
  ax2.set_frame_on(True)
417
610
  ax2.patch.set_visible(False)
@@ -419,72 +612,102 @@ class scat:
419
612
  for sp in ax2.spines.values():
420
613
  sp.set_visible(False)
421
614
  ax2.spines["bottom"].set_visible(True)
422
- ax2.set_xlim(0,n+2)
615
+ ax2.set_xlim(0, n + 2)
423
616
  ax2.set_xticks(tab2x)
424
- ax2.set_xticklabels(tab2nx,fontsize=6)
425
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
617
+ ax2.set_xticklabels(tab2nx, fontsize=6)
618
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
426
619
  ax1.legend(frameon=0)
427
-
428
- ax1=plt.subplot(2, 2, 4)
620
+
621
+ ax1 = plt.subplot(2, 2, 4)
429
622
  ax2 = ax1.twiny()
430
- n=0
431
- tmp=abs(self.get_np(self.S2L))
432
- lname=r'%s $S2_{2}$' % (name)
433
- ax1.set_ylabel(r'$S_{2}$ [L2 norm]')
434
- test=None
435
- tabx=[]
436
- tabnx=[]
437
- tab2x=[]
438
- tab2nx=[]
439
- if len(tmp.shape)==5:
623
+ n = 0
624
+ tmp = abs(self.get_np(self.S2L))
625
+ lname = r"%s $S2_{2}$" % (name)
626
+ ax1.set_ylabel(r"$S_{2}$ [L2 norm]")
627
+ test = None
628
+ tabx = []
629
+ tabnx = []
630
+ tab2x = []
631
+ tab2nx = []
632
+ if len(tmp.shape) == 5:
440
633
  for i0 in range(tmp.shape[0]):
441
634
  for i1 in range(tmp.shape[1]):
442
- for i2 in range(j1.max()+1):
635
+ for i2 in range(j1.max() + 1):
443
636
  for i3 in range(tmp.shape[3]):
444
637
  for i4 in range(tmp.shape[4]):
445
- if j2[j1==i2].shape[0]==1:
446
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4],'.', \
447
- color=color, lw=lw)
638
+ if j2[j1 == i2].shape[0] == 1:
639
+ ax1.plot(
640
+ j2[j1 == i2] + n,
641
+ tmp[i0, i1, j1 == i2, i3, i4],
642
+ ".",
643
+ color=color,
644
+ lw=lw,
645
+ )
448
646
  else:
449
647
  if legend and test is None:
450
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
451
- color=color, label=lname, lw=lw)
452
- test=1
453
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2,i3,i4], \
454
- color=color, lw=lw)
455
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
456
- tabx=tabx+[k+n for k in j2[j1==i2]]
457
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
458
- tab2nx=tab2nx+['%d'%(i2)]
459
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
460
- n=n+j2[j1==i2].shape[0]-1
648
+ ax1.plot(
649
+ j2[j1 == i2] + n,
650
+ tmp[i0, i1, j1 == i2, i3, i4],
651
+ color=color,
652
+ label=lname,
653
+ lw=lw,
654
+ )
655
+ test = 1
656
+ ax1.plot(
657
+ j2[j1 == i2] + n,
658
+ tmp[i0, i1, j1 == i2, i3, i4],
659
+ color=color,
660
+ lw=lw,
661
+ )
662
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
663
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
664
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
665
+ tab2nx = tab2nx + ["%d" % (i2)]
666
+ ax1.axvline(
667
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
668
+ )
669
+ n = n + j2[j1 == i2].shape[0] - 1
461
670
  else:
462
671
  for i0 in range(tmp.shape[0]):
463
- for i2 in range(j1.max()+1):
672
+ for i2 in range(j1.max() + 1):
464
673
  for i3 in range(tmp.shape[2]):
465
674
  for i4 in range(tmp.shape[3]):
466
- if j2[j1==i2].shape[0]==1:
467
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4],'.', \
468
- color=color, lw=lw)
675
+ if j2[j1 == i2].shape[0] == 1:
676
+ ax1.plot(
677
+ j2[j1 == i2] + n,
678
+ tmp[i0, j1 == i2, i3, i4],
679
+ ".",
680
+ color=color,
681
+ lw=lw,
682
+ )
469
683
  else:
470
684
  if legend and test is None:
471
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
472
- color=color, label=lname, lw=lw)
473
- test=1
474
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2,i3,i4], \
475
- color=color, lw=lw)
476
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
477
- tabx=tabx+[k+n for k in j2[j1==i2]]
478
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
479
- tab2nx=tab2nx+['%d'%(i2)]
480
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
481
- n=n+j2[j1==i2].shape[0]-1
482
- plt.yscale('log')
483
- ax1.set_xlim(-1,n+3)
685
+ ax1.plot(
686
+ j2[j1 == i2] + n,
687
+ tmp[i0, j1 == i2, i3, i4],
688
+ color=color,
689
+ label=lname,
690
+ lw=lw,
691
+ )
692
+ test = 1
693
+ ax1.plot(
694
+ j2[j1 == i2] + n,
695
+ tmp[i0, j1 == i2, i3, i4],
696
+ color=color,
697
+ lw=lw,
698
+ )
699
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
700
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
701
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
702
+ tab2nx = tab2nx + ["%d" % (i2)]
703
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
704
+ n = n + j2[j1 == i2].shape[0] - 1
705
+ plt.yscale("log")
706
+ ax1.set_xlim(-1, n + 3)
484
707
  ax1.set_xticks(tabx)
485
- ax1.set_xticklabels(tabnx,fontsize=6)
486
- ax1.set_xlabel(r"$j_{2}$",fontsize=6)
487
-
708
+ ax1.set_xticklabels(tabnx, fontsize=6)
709
+ ax1.set_xlabel(r"$j_{2}$", fontsize=6)
710
+
488
711
  # Move twinned axis ticks and label from top to bottom
489
712
  ax2.xaxis.set_ticks_position("bottom")
490
713
  ax2.xaxis.set_label_position("bottom")
@@ -492,7 +715,7 @@ class scat:
492
715
  # Offset the twin axis below the host
493
716
  ax2.spines["bottom"].set_position(("axes", -0.15))
494
717
 
495
- # Turn on the frame for the twin axis, but then hide all
718
+ # Turn on the frame for the twin axis, but then hide all
496
719
  # but the bottom spine
497
720
  ax2.set_frame_on(True)
498
721
  ax2.patch.set_visible(False)
@@ -500,248 +723,351 @@ class scat:
500
723
  for sp in ax2.spines.values():
501
724
  sp.set_visible(False)
502
725
  ax2.spines["bottom"].set_visible(True)
503
- ax2.set_xlim(0,n+3)
726
+ ax2.set_xlim(0, n + 3)
504
727
  ax2.set_xticks(tab2x)
505
- ax2.set_xticklabels(tab2nx,fontsize=6)
506
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
728
+ ax2.set_xticklabels(tab2nx, fontsize=6)
729
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
507
730
  ax1.legend(frameon=0)
508
-
509
- def save(self,filename):
510
- outlist=[self.get_S0().numpy(), \
511
- self.get_S1().numpy(), \
512
- self.get_S2().numpy(), \
513
- self.get_S2L().numpy(), \
514
- self.get_P00().numpy(), \
515
- self.j1, \
516
- self.j2]
517
-
518
- myout=open("%s.pkl"%(filename),"wb")
519
- pickle.dump(outlist,myout)
731
+
732
+ def save(self, filename):
733
+ outlist = [
734
+ self.get_S0().numpy(),
735
+ self.get_S1().numpy(),
736
+ self.get_S2().numpy(),
737
+ self.get_S2L().numpy(),
738
+ self.get_P00().numpy(),
739
+ self.j1,
740
+ self.j2,
741
+ ]
742
+
743
+ myout = open("%s.pkl" % (filename), "wb")
744
+ pickle.dump(outlist, myout)
520
745
  myout.close()
521
746
 
522
-
523
- def read(self,filename):
524
-
525
- outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
526
- return scat(outlist[4],outlist[0],outlist[1],outlist[2],outlist[3],outlist[5],outlist[6],backend=bk.foscat_backend('numpy'))
527
-
528
- def get_np(self,x):
747
+ def read(self, filename):
748
+
749
+ outlist = pickle.load(open("%s.pkl" % (filename), "rb"))
750
+ return scat(
751
+ outlist[4],
752
+ outlist[0],
753
+ outlist[1],
754
+ outlist[2],
755
+ outlist[3],
756
+ outlist[5],
757
+ outlist[6],
758
+ backend=bk.foscat_backend("numpy"),
759
+ )
760
+
761
+ def get_np(self, x):
529
762
  if isinstance(x, np.ndarray):
530
763
  return x
531
764
  else:
532
765
  return x.numpy()
533
766
 
534
767
  def std(self):
535
- return np.sqrt(((abs(self.get_np(self.S0)).std())**2+ \
536
- (abs(self.get_np(self.S1)).std())**2+ \
537
- (abs(self.get_np(self.S2)).std())**2+ \
538
- (abs(self.get_np(self.S2L)).std())**2+ \
539
- (abs(self.get_np(self.P00)).std())**2)/4)
768
+ return np.sqrt(
769
+ (
770
+ (abs(self.get_np(self.S0)).std()) ** 2
771
+ + (abs(self.get_np(self.S1)).std()) ** 2
772
+ + (abs(self.get_np(self.S2)).std()) ** 2
773
+ + (abs(self.get_np(self.S2L)).std()) ** 2
774
+ + (abs(self.get_np(self.P00)).std()) ** 2
775
+ )
776
+ / 4
777
+ )
540
778
 
541
779
  def mean(self):
542
- return abs(self.get_np(self.S0).mean()+ \
543
- self.get_np(self.S1).mean()+ \
544
- self.get_np(self.S2).mean()+ \
545
- self.get_np(self.S2L).mean()+ \
546
- self.get_np(self.P00).mean())/3
780
+ return (
781
+ abs(
782
+ self.get_np(self.S0).mean()
783
+ + self.get_np(self.S1).mean()
784
+ + self.get_np(self.S2).mean()
785
+ + self.get_np(self.S2L).mean()
786
+ + self.get_np(self.P00).mean()
787
+ )
788
+ / 3
789
+ )
547
790
 
548
791
  def sqrt(self):
549
792
 
793
+ s0 = self.backend.bk_sqrt(self.S0)
794
+ s1 = self.backend.bk_sqrt(self.S1)
795
+ p00 = self.backend.bk_sqrt(self.P00)
796
+ s2 = self.backend.bk_sqrt(self.S2)
797
+ s2L = self.backend.bk_sqrt(self.S2L)
550
798
 
551
- s0 =self.backend.bk_sqrt(self.S0)
552
- s1 =self.backend.bk_sqrt(self.S1)
553
- p00=self.backend.bk_sqrt(self.P00)
554
- s2 =self.backend.bk_sqrt(self.S2)
555
- s2L=self.backend.bk_sqrt(self.S2L)
556
-
557
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
558
-
799
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
559
800
 
560
801
  def L1(self):
561
802
 
803
+ s0 = self.backend.bk_L1(self.S0)
804
+ s1 = self.backend.bk_L1(self.S1)
805
+ p00 = self.backend.bk_L1(self.P00)
806
+ s2 = self.backend.bk_L1(self.S2)
807
+ s2L = self.backend.bk_L1(self.S2L)
808
+
809
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
562
810
 
563
- s0 =self.backend.bk_L1(self.S0)
564
- s1 =self.backend.bk_L1(self.S1)
565
- p00=self.backend.bk_L1(self.P00)
566
- s2 =self.backend.bk_L1(self.S2)
567
- s2L=self.backend.bk_L1(self.S2L)
568
-
569
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
570
-
571
811
  def square_comp(self):
572
812
 
813
+ s0 = self.backend.bk_square_comp(self.S0)
814
+ s1 = self.backend.bk_square_comp(self.S1)
815
+ p00 = self.backend.bk_square_comp(self.P00)
816
+ s2 = self.backend.bk_square_comp(self.S2)
817
+ s2L = self.backend.bk_square_comp(self.S2L)
818
+
819
+ return scat(p00, s0, s1, s2, s2L, self.j1, self.j2, backend=self.backend)
573
820
 
574
- s0 =self.backend.bk_square_comp(self.S0)
575
- s1 =self.backend.bk_square_comp(self.S1)
576
- p00=self.backend.bk_square_comp(self.P00)
577
- s2 =self.backend.bk_square_comp(self.S2)
578
- s2L=self.backend.bk_square_comp(self.S2L)
579
-
580
- return scat(p00,s0,s1,s2,s2L,self.j1,self.j2,backend=self.backend)
581
-
582
- def iso_mean(self,repeat=False):
583
- shape=list(self.S2.shape)
584
- norient=self.S1.shape[2]
821
+ def iso_mean(self, repeat=False):
822
+ shape = list(self.S2.shape)
823
+ norient = self.S1.shape[2]
585
824
 
586
- S1 = self.backend.bk_reduce_mean(self.S1,2)
825
+ S1 = self.backend.bk_reduce_mean(self.S1, 2)
587
826
  if repeat:
588
- S1=self.backend.bk_reshape(self.backend.bk_repeat(S1,shape[2],1),self.S1.shape)
827
+ S1 = self.backend.bk_reshape(
828
+ self.backend.bk_repeat(S1, shape[2], 1), self.S1.shape
829
+ )
589
830
  else:
590
- S1=self.backend.bk_expand_dims(S1,-1)
591
-
831
+ S1 = self.backend.bk_expand_dims(S1, -1)
592
832
 
593
- P00 = self.backend.bk_reduce_mean(self.P00,2)
833
+ P00 = self.backend.bk_reduce_mean(self.P00, 2)
594
834
  if repeat:
595
- P00=self.backend.bk_reshape(self.backend.bk_repeat(P00,shape[2],1),self.S1.shape)
835
+ P00 = self.backend.bk_reshape(
836
+ self.backend.bk_repeat(P00, shape[2], 1), self.S1.shape
837
+ )
596
838
  else:
597
- P00=self.backend.bk_expand_dims(P00,-1)
839
+ P00 = self.backend.bk_expand_dims(P00, -1)
598
840
 
599
841
  if norient not in self.backend._iso_orient:
600
842
  self.backend.calc_iso_orient(norient)
601
-
602
- if self.S2.dtype=='complex128' or self.S2.dtype=='complex64':
603
- lmat = self.backend._iso_orient_C[norient]
843
+
844
+ if self.backend.bk_is_complex(self.S2):
845
+ lmat = self.backend._iso_orient_C[norient]
604
846
  lmat_T = self.backend._iso_orient_C_T[norient]
605
847
  else:
606
- lmat = self.backend._iso_orient[norient]
848
+ lmat = self.backend._iso_orient[norient]
607
849
  lmat_T = self.backend._iso_orient_T[norient]
608
-
609
- S2=self.backend.bk_reshape(
610
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2,[shape[0],shape[1],norient*norient]),lmat),
611
- [shape[0],shape[1],norient])
612
- S2L=self.backend.bk_reshape(
613
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2L,[shape[0],shape[1],norient*norient]),lmat),
614
- [shape[0],shape[1],norient])
615
-
616
- if repeat:
617
-
618
- S2=self.backend.bk_reshape(
619
- self.backend.backend.matmul(self.backend.bk_reshape(S2,[shape[0]*shape[1],norient]),lmat_T),
620
- self.S2.shape)
621
- S2L=self.backend.bk_reshape(
622
- self.backend.backend.matmul(self.backend.bk_reshape(S2L,[shape[0]*shape[1],norient]),lmat_T),
623
- self.S2.shape)
624
- else:
625
- S2=self.backend.bk_expand_dims(S2,-1)
626
- S2L=self.backend.bk_expand_dims(S2L,-1)
627
850
 
628
- return scat(P00,self.S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
851
+ S2 = self.backend.bk_reshape(
852
+ self.backend.backend.matmul(
853
+ self.backend.bk_reshape(
854
+ self.S2, [shape[0], shape[1], norient * norient]
855
+ ),
856
+ lmat,
857
+ ),
858
+ [shape[0], shape[1], norient],
859
+ )
860
+ S2L = self.backend.bk_reshape(
861
+ self.backend.backend.matmul(
862
+ self.backend.bk_reshape(
863
+ self.S2L, [shape[0], shape[1], norient * norient]
864
+ ),
865
+ lmat,
866
+ ),
867
+ [shape[0], shape[1], norient],
868
+ )
629
869
 
630
-
631
- def fft_ang(self,nharm=1):
632
- shape=list(self.S2.shape)
633
- norient=self.S1.shape[2]
870
+ if repeat:
634
871
 
635
- if (norient,nharm) not in self.backend._fft_1_orient:
636
- self.backend.calc_fft_orient(norient,nharm)
637
-
638
- if self.S1.dtype=='complex128' or self.S1.dtype=='complex64':
639
- lmat = self.backend._fft_1_orient_C[(norient,nharm)]
872
+ S2 = self.backend.bk_reshape(
873
+ self.backend.backend.matmul(
874
+ self.backend.bk_reshape(S2, [shape[0] * shape[1], norient]), lmat_T
875
+ ),
876
+ self.S2.shape,
877
+ )
878
+ S2L = self.backend.bk_reshape(
879
+ self.backend.backend.matmul(
880
+ self.backend.bk_reshape(S2L, [shape[0] * shape[1], norient]), lmat_T
881
+ ),
882
+ self.S2.shape,
883
+ )
640
884
  else:
641
- lmat = self.backend._fft_1_orient[(norient,nharm)]
642
-
643
- S1=self.backend.bk_reshape(
644
- self.backend.backend.matmul(self.backend.bk_reshape(self.S1,[self.S1.shape[0],self.S1.shape[1],norient]),lmat),
645
- [self.S1.shape[0],self.S1.shape[1],1+nharm])
646
-
647
- P00=self.backend.bk_reshape(
648
- self.backend.backend.matmul(self.backend.bk_reshape(self.P00,[self.S1.shape[0],self.S1.shape[1],norient]),lmat),
649
- [self.S1.shape[0],self.S1.shape[1],1+nharm])
650
-
651
-
652
- if self.S2.dtype=='complex128' or self.S2.dtype=='complex64':
653
- lmat = self.backend._fft_2_orient_C[(norient,nharm)]
654
- else:
655
- lmat = self.backend._fft_2_orient[(norient,nharm)]
656
-
657
- S2=self.backend.bk_reshape(
658
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2,[shape[0],shape[1],norient*norient]),lmat),
659
- [shape[0],shape[1],1+nharm,1+nharm])
660
- S2L=self.backend.bk_reshape(
661
- self.backend.backend.matmul(self.backend.bk_reshape(self.S2L,[shape[0],shape[1],norient*norient]),lmat),
662
- [shape[0],shape[1],1+nharm,1+nharm])
885
+ S2 = self.backend.bk_expand_dims(S2, -1)
886
+ S2L = self.backend.bk_expand_dims(S2L, -1)
887
+
888
+ return scat(P00, self.S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
663
889
 
664
- return scat(P00,self.S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
890
+ def fft_ang(self, nharm=1, imaginary=False):
891
+ shape = list(self.S2.shape)
892
+ norient = self.S1.shape[2]
665
893
 
894
+ nout = 1 + nharm
895
+ if imaginary:
896
+ nout = 1 + nharm * 2
666
897
 
667
- def iso_std(self,repeat=False):
898
+ if (norient, nharm) not in self.backend._fft_1_orient:
899
+ self.backend.calc_fft_orient(norient, nharm, imaginary)
668
900
 
669
- val=(self-self.iso_mean(repeat=True)).square_comp()
901
+ if self.backend.bk_is_complex(self.S1):
902
+ lmat = self.backend._fft_1_orient_C[(norient, nharm, imaginary)]
903
+ else:
904
+ lmat = self.backend._fft_1_orient[(norient, nharm, imaginary)]
905
+
906
+ S1 = self.backend.bk_reshape(
907
+ self.backend.backend.matmul(
908
+ self.backend.bk_reshape(
909
+ self.S1, [self.S1.shape[0], self.S1.shape[1], norient]
910
+ ),
911
+ lmat,
912
+ ),
913
+ [self.S1.shape[0], self.S1.shape[1], nout],
914
+ )
915
+
916
+ P00 = self.backend.bk_reshape(
917
+ self.backend.backend.matmul(
918
+ self.backend.bk_reshape(
919
+ self.P00, [self.S1.shape[0], self.S1.shape[1], norient]
920
+ ),
921
+ lmat,
922
+ ),
923
+ [self.S1.shape[0], self.S1.shape[1], nout],
924
+ )
925
+
926
+ if self.backend.bk_is_complex(self.S2):
927
+ lmat = self.backend._fft_2_orient_C[(norient, nharm, imaginary)]
928
+ else:
929
+ lmat = self.backend._fft_2_orient[(norient, nharm, imaginary)]
930
+
931
+ S2 = self.backend.bk_reshape(
932
+ self.backend.backend.matmul(
933
+ self.backend.bk_reshape(
934
+ self.S2, [shape[0], shape[1], norient * norient]
935
+ ),
936
+ lmat,
937
+ ),
938
+ [shape[0], shape[1], nout, nout],
939
+ )
940
+ S2L = self.backend.bk_reshape(
941
+ self.backend.backend.matmul(
942
+ self.backend.bk_reshape(
943
+ self.S2L, [shape[0], shape[1], norient * norient]
944
+ ),
945
+ lmat,
946
+ ),
947
+ [shape[0], shape[1], nout, nout],
948
+ )
949
+
950
+ return scat(P00, self.S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
951
+
952
+ def iso_std(self, repeat=False):
953
+
954
+ val = (self - self.iso_mean(repeat=True)).square_comp()
670
955
  return (val.iso_mean(repeat=repeat)).L1()
671
956
 
672
957
  # ---------------------------------------------−---------
673
- def cleanval(self,x):
674
- x=x.numpy()
675
- x[np.isfinite(x)==False]=np.median(x[np.isfinite(x)])
958
+ def cleanval(self, x):
959
+ x = x.numpy()
960
+ x[~np.isfinite(x)] = np.median(x[np.isfinite(x)])
676
961
  return x
677
962
 
678
963
  def filter_inf(self):
679
- S1 = self.cleanval(self.S1)
680
- S0 = self.cleanval(self.S0)
964
+ S1 = self.cleanval(self.S1)
965
+ S0 = self.cleanval(self.S0)
681
966
  P00 = self.cleanval(self.P00)
682
- S2 = self.cleanval(self.S2)
967
+ S2 = self.cleanval(self.S2)
683
968
  S2L = self.cleanval(self.S2L)
684
969
 
685
- return scat(P00,S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
970
+ return scat(P00, S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
686
971
 
687
972
  # ---------------------------------------------−---------
688
- def interp(self,nscale,extend=False,constant=False,threshold=1E30,use_mask=False):
689
-
690
- if nscale+2>self.S1.shape[1]:
691
- print('Can not *interp* %d with a statistic described over %d'%(nscale,self.S1.shape[1]))
692
- return scat(self.P00,self.S0,self.S1,self.S2,self.S2L,self.j1,self.j2,backend=self.backend)
693
- if isinstance(self.S1,np.ndarray):
694
- s1=self.S1
695
- p0=self.P00
696
- s2=self.S2
697
- s2l=self.S2L
973
+ def interp(
974
+ self, nscale, extend=False, constant=False, threshold=1e30, use_mask=False
975
+ ):
976
+
977
+ if nscale + 2 > self.S1.shape[1]:
978
+ print(
979
+ "Can not *interp* %d with a statistic described over %d"
980
+ % (nscale, self.S1.shape[1])
981
+ )
982
+ return scat(
983
+ self.P00,
984
+ self.S0,
985
+ self.S1,
986
+ self.S2,
987
+ self.S2L,
988
+ self.j1,
989
+ self.j2,
990
+ backend=self.backend,
991
+ )
992
+ if isinstance(self.S1, np.ndarray):
993
+ s1 = self.S1
994
+ p0 = self.P00
995
+ s2 = self.S2
996
+ s2l = self.S2L
698
997
  else:
699
- s1=self.S1.numpy()
700
- p0=self.P00.numpy()
701
- s2=self.S2.numpy()
702
- s2l=self.S2L.numpy()
703
-
704
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
705
-
706
- if isinstance(threshold,scat):
707
- if isinstance(threshold.S1,np.ndarray):
708
- s1th=threshold.S1
709
- p0th=threshold.P00
710
- s2th=threshold.S2
711
- s2lth=threshold.S2L
998
+ s1 = self.S1.numpy()
999
+ p0 = self.P00.numpy()
1000
+ s2 = self.S2.numpy()
1001
+ s2l = self.S2L.numpy()
1002
+
1003
+ print(s1.sum(), p0.sum(), s2.sum(), s2l.sum())
1004
+
1005
+ if isinstance(threshold, scat):
1006
+ if isinstance(threshold.S1, np.ndarray):
1007
+ s1th = threshold.S1
1008
+ p0th = threshold.P00
1009
+ s2th = threshold.S2
1010
+ s2lth = threshold.S2L
712
1011
  else:
713
- s1th=threshold.S1.numpy()
714
- p0th=threshold.P00.numpy()
715
- s2th=threshold.S2.numpy()
716
- s2lth=threshold.S2L.numpy()
1012
+ s1th = threshold.S1.numpy()
1013
+ p0th = threshold.P00.numpy()
1014
+ s2th = threshold.S2.numpy()
1015
+ s2lth = threshold.S2L.numpy()
717
1016
  else:
718
- s1th=threshold+0*s1
719
- p0th=threshold+0*p0
720
- s2th=threshold+0*s2
721
- s2lth=threshold+0*s2l
1017
+ s1th = threshold + 0 * s1
1018
+ p0th = threshold + 0 * p0
1019
+ s2th = threshold + 0 * s2
1020
+ s2lth = threshold + 0 * s2l
722
1021
 
723
1022
  for k in range(nscale):
724
1023
  if constant:
725
- s1[:,nscale-1-k,:]=s1[:,nscale-k,:]
726
- p0[:,nscale-1-k,:]=p0[:,nscale-k,:]
1024
+ s1[:, nscale - 1 - k, :] = s1[:, nscale - k, :]
1025
+ p0[:, nscale - 1 - k, :] = p0[:, nscale - k, :]
727
1026
  else:
728
- idx=np.where((s1[:,nscale+1-k,:]>0)*(s1[:,nscale+2-k,:]>0)*(s1[:,nscale-k,:]<s1th[:,nscale-k,:]))
729
- if len(idx[0])>0:
730
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(s1[idx[0],nscale+1-k,idx[1]])-2*np.log(s1[idx[0],nscale+2-k,idx[1]]))
731
- idx=np.where((s1[:,nscale-k,:]>0)*(s1[:,nscale+1-k,:]>0)*(s1[:,nscale-1-k,:]<s1th[:,nscale-1-k,:]))
732
- if len(idx[0])>0:
733
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(s1[idx[0],nscale-k,idx[1]])-np.log(s1[idx[0],nscale+1-k,idx[1]]))
734
-
735
- idx=np.where((p0[:,nscale+1-k,:]>0)*(p0[:,nscale+2-k,:]>0)*(p0[:,nscale-k,:]<p0th[:,nscale-k,:]))
736
- if len(idx[0])>0:
737
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(p0[idx[0],nscale+1-k,idx[1]])-2*np.log(p0[idx[0],nscale+2-k,idx[1]]))
738
-
739
- idx=np.where((p0[:,nscale-k,:]>0)*(p0[:,nscale+1-k,:]>0)*(p0[:,nscale-1-k,:]<p0th[:,nscale-1-k,:]))
740
- if len(idx[0])>0:
741
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(p0[idx[0],nscale-k,idx[1]])-np.log(p0[idx[0],nscale+1-k,idx[1]]))
742
-
743
-
744
- j1,j2=self.get_j_idx()
1027
+ idx = np.where(
1028
+ (s1[:, nscale + 1 - k, :] > 0)
1029
+ * (s1[:, nscale + 2 - k, :] > 0)
1030
+ * (s1[:, nscale - k, :] < s1th[:, nscale - k, :])
1031
+ )
1032
+ if len(idx[0]) > 0:
1033
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1034
+ 3 * np.log(s1[idx[0], nscale + 1 - k, idx[1]])
1035
+ - 2 * np.log(s1[idx[0], nscale + 2 - k, idx[1]])
1036
+ )
1037
+ idx = np.where(
1038
+ (s1[:, nscale - k, :] > 0)
1039
+ * (s1[:, nscale + 1 - k, :] > 0)
1040
+ * (s1[:, nscale - 1 - k, :] < s1th[:, nscale - 1 - k, :])
1041
+ )
1042
+ if len(idx[0]) > 0:
1043
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1044
+ 2 * np.log(s1[idx[0], nscale - k, idx[1]])
1045
+ - np.log(s1[idx[0], nscale + 1 - k, idx[1]])
1046
+ )
1047
+
1048
+ idx = np.where(
1049
+ (p0[:, nscale + 1 - k, :] > 0)
1050
+ * (p0[:, nscale + 2 - k, :] > 0)
1051
+ * (p0[:, nscale - k, :] < p0th[:, nscale - k, :])
1052
+ )
1053
+ if len(idx[0]) > 0:
1054
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1055
+ 3 * np.log(p0[idx[0], nscale + 1 - k, idx[1]])
1056
+ - 2 * np.log(p0[idx[0], nscale + 2 - k, idx[1]])
1057
+ )
1058
+
1059
+ idx = np.where(
1060
+ (p0[:, nscale - k, :] > 0)
1061
+ * (p0[:, nscale + 1 - k, :] > 0)
1062
+ * (p0[:, nscale - 1 - k, :] < p0th[:, nscale - 1 - k, :])
1063
+ )
1064
+ if len(idx[0]) > 0:
1065
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
1066
+ 2 * np.log(p0[idx[0], nscale - k, idx[1]])
1067
+ - np.log(p0[idx[0], nscale + 1 - k, idx[1]])
1068
+ )
1069
+
1070
+ j1, j2 = self.get_j_idx()
745
1071
 
746
1072
  for k in range(nscale):
747
1073
 
@@ -754,647 +1080,951 @@ class scat:
754
1080
  s2l[:,i0]=np.exp(2*np.log(s2l[:,i1])-np.log(s2l[:,i2]))
755
1081
  """
756
1082
 
757
- for l in range(nscale-k):
758
- i0=np.where((j1==nscale-1-k-l)*(j2==nscale-1-k))[0]
759
- i1=np.where((j1==nscale-1-k-l)*(j2==nscale -k))[0]
760
- i2=np.where((j1==nscale-1-k-l)*(j2==nscale+1-k))[0]
761
- i3=np.where((j1==nscale-1-k-l)*(j2==nscale+2-k))[0]
762
-
1083
+ for l_scale in range(nscale - k):
1084
+ i0 = np.where(
1085
+ (j1 == nscale - 1 - k - l_scale) * (j2 == nscale - 1 - k)
1086
+ )[0]
1087
+ i1 = np.where((j1 == nscale - 1 - k - l_scale) * (j2 == nscale - k))[0]
1088
+ i2 = np.where(
1089
+ (j1 == nscale - 1 - k - l_scale) * (j2 == nscale + 1 - k)
1090
+ )[0]
1091
+ i3 = np.where(
1092
+ (j1 == nscale - 1 - k - l_scale) * (j2 == nscale + 2 - k)
1093
+ )[0]
1094
+
763
1095
  if constant:
764
- s2[:,i0]=s2[:,i1]
765
- s2l[:,i0]=s2l[:,i1]
1096
+ s2[:, i0] = s2[:, i1]
1097
+ s2l[:, i0] = s2l[:, i1]
766
1098
  else:
767
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
768
- if len(idx[0])>0:
769
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
770
-
771
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
772
- if len(idx[0])>0:
773
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
774
-
775
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
776
- if len(idx[0])>0:
777
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
778
-
779
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
780
- if len(idx[0])>0:
781
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
782
-
1099
+ idx = np.where(
1100
+ (s2[:, i2] > 0) * (s2[:, i3] > 0) * (s2[:, i2] < s2th[:, i2])
1101
+ )
1102
+ if len(idx[0]) > 0:
1103
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1104
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
1105
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
1106
+ )
1107
+
1108
+ idx = np.where(
1109
+ (s2[:, i1] > 0) * (s2[:, i2] > 0) * (s2[:, i1] < s2th[:, i1])
1110
+ )
1111
+ if len(idx[0]) > 0:
1112
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1113
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1114
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1115
+ )
1116
+
1117
+ idx = np.where(
1118
+ (s2l[:, i2] > 0)
1119
+ * (s2l[:, i3] > 0)
1120
+ * (s2l[:, i2] < s2lth[:, i2])
1121
+ )
1122
+ if len(idx[0]) > 0:
1123
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1124
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1125
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1126
+ )
1127
+
1128
+ idx = np.where(
1129
+ (s2l[:, i1] > 0)
1130
+ * (s2l[:, i2] > 0)
1131
+ * (s2l[:, i1] < s2lth[:, i1])
1132
+ )
1133
+ if len(idx[0]) > 0:
1134
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1135
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1136
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1137
+ )
1138
+
783
1139
  if extend:
784
1140
  for k in range(nscale):
785
- for l in range(1,nscale):
786
- i0=np.where((j1==2*nscale-1-k)*(j2==2*nscale-1-k-l))[0]
787
- i1=np.where((j1==2*nscale-1-k)*(j2==2*nscale -k-l))[0]
788
- i2=np.where((j1==2*nscale-1-k)*(j2==2*nscale+1-k-l))[0]
789
- i3=np.where((j1==2*nscale-1-k)*(j2==2*nscale+2-k-l))[0]
1141
+ for l_scale in range(1, nscale):
1142
+ i0 = np.where(
1143
+ (j1 == 2 * nscale - 1 - k)
1144
+ * (j2 == 2 * nscale - 1 - k - l_scale)
1145
+ )[0]
1146
+ i1 = np.where(
1147
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale - k - l_scale)
1148
+ )[0]
1149
+ i2 = np.where(
1150
+ (j1 == 2 * nscale - 1 - k)
1151
+ * (j2 == 2 * nscale + 1 - k - l_scale)
1152
+ )[0]
1153
+ i3 = np.where(
1154
+ (j1 == 2 * nscale - 1 - k)
1155
+ * (j2 == 2 * nscale + 2 - k - l_scale)
1156
+ )[0]
790
1157
  if constant:
791
- s2[:,i0]=s2[:,i1]
792
- s2l[:,i0]=s2l[:,i1]
1158
+ s2[:, i0] = s2[:, i1]
1159
+ s2l[:, i0] = s2l[:, i1]
793
1160
  else:
794
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
795
- if len(idx[0])>0:
796
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
797
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
798
- if len(idx[0])>0:
799
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
800
-
801
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
802
- if len(idx[0])>0:
803
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
804
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
805
- if len(idx[0])>0:
806
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
807
-
808
- s1[np.isnan(s1)]=0.0
809
- p0[np.isnan(p0)]=0.0
810
- s2[np.isnan(s2)]=0.0
811
- s2l[np.isnan(s2l)]=0.0
812
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
813
-
814
- return scat(self.backend.constant(p0),self.S0,
815
- self.backend.constant(s1),
816
- self.backend.constant(s2),
817
- self.backend.constant(s2l),self.j1,self.j2,backend=self.backend)
1161
+ idx = np.where(
1162
+ (s2[:, i2] > 0)
1163
+ * (s2[:, i3] > 0)
1164
+ * (s2[:, i2] < s2th[:, i2])
1165
+ )
1166
+ if len(idx[0]) > 0:
1167
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1168
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
1169
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
1170
+ )
1171
+ idx = np.where(
1172
+ (s2[:, i1] > 0)
1173
+ * (s2[:, i2] > 0)
1174
+ * (s2[:, i1] < s2th[:, i1])
1175
+ )
1176
+ if len(idx[0]) > 0:
1177
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1178
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1179
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1180
+ )
1181
+
1182
+ idx = np.where(
1183
+ (s2l[:, i2] > 0)
1184
+ * (s2l[:, i3] > 0)
1185
+ * (s2l[:, i2] < s2lth[:, i2])
1186
+ )
1187
+ if len(idx[0]) > 0:
1188
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1189
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1190
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1191
+ )
1192
+ idx = np.where(
1193
+ (s2l[:, i1] > 0)
1194
+ * (s2l[:, i2] > 0)
1195
+ * (s2l[:, i1] < s2lth[:, i1])
1196
+ )
1197
+ if len(idx[0]) > 0:
1198
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1199
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1200
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1201
+ )
1202
+
1203
+ s1[np.isnan(s1)] = 0.0
1204
+ p0[np.isnan(p0)] = 0.0
1205
+ s2[np.isnan(s2)] = 0.0
1206
+ s2l[np.isnan(s2l)] = 0.0
1207
+ print(s1.sum(), p0.sum(), s2.sum(), s2l.sum())
1208
+
1209
+ return scat(
1210
+ self.backend.constant(p0),
1211
+ self.S0,
1212
+ self.backend.constant(s1),
1213
+ self.backend.constant(s2),
1214
+ self.backend.constant(s2l),
1215
+ self.j1,
1216
+ self.j2,
1217
+ backend=self.backend,
1218
+ )
818
1219
 
819
1220
  # ---------------------------------------------−---------
820
1221
  def flatten(self):
821
- if isinstance(self.S1,np.ndarray):
822
- return np.concatenate([self.S0.flatten(),
823
- self.S1.flatten(),
824
- self.P00.flatten(),
825
- self.S2.flatten(),
826
- self.S2L.flatten()],0)
1222
+ if isinstance(self.S1, np.ndarray):
1223
+ return np.concatenate(
1224
+ [
1225
+ self.S0.flatten(),
1226
+ self.S1.flatten(),
1227
+ self.P00.flatten(),
1228
+ self.S2.flatten(),
1229
+ self.S2L.flatten(),
1230
+ ],
1231
+ 0,
1232
+ )
827
1233
  else:
828
- return self.backend.bk_concat([self.backend.bk_flattenR(self.S0),
829
- self.backend.bk_flattenR(self.S1),
830
- self.backend.bk_flattenR(self.P00),
831
- self.backend.bk_flattenR(self.S2),
832
- self.backend.bk_flattenR(self.S2)],axis=0)
1234
+ return self.backend.bk_concat(
1235
+ [
1236
+ self.backend.bk_flattenR(self.S0),
1237
+ self.backend.bk_flattenR(self.S1),
1238
+ self.backend.bk_flattenR(self.P00),
1239
+ self.backend.bk_flattenR(self.S2),
1240
+ self.backend.bk_flattenR(self.S2),
1241
+ ],
1242
+ axis=0,
1243
+ )
833
1244
 
834
1245
  # ---------------------------------------------−---------
835
1246
  def flattenMask(self):
836
- if isinstance(self.S1,np.ndarray):
837
- tmp=np.expand_dims(np.concatenate([self.S1[0].flatten(),
838
- self.P00[0].flatten(),
839
- self.S2[0].flatten(),
840
- self.S2L[0].flatten()],0),0)
841
- for k in range(1,self.P00.shape[0]):
842
- tmp=np.concatenate([tmp,np.expand_dims(np.concatenate([self.S1[k].flatten(),
843
- self.P00[k].flatten(),
844
- self.S2[k].flatten(),
845
- self.S2L[k].flatten()],0),0)],0)
846
-
847
-
848
- return np.concatenate([tmp,np.expand_dims(self.S0,1)],1)
1247
+ if isinstance(self.S1, np.ndarray):
1248
+ tmp = np.expand_dims(
1249
+ np.concatenate(
1250
+ [
1251
+ self.S1[0].flatten(),
1252
+ self.P00[0].flatten(),
1253
+ self.S2[0].flatten(),
1254
+ self.S2L[0].flatten(),
1255
+ ],
1256
+ 0,
1257
+ ),
1258
+ 0,
1259
+ )
1260
+ for k in range(1, self.P00.shape[0]):
1261
+ tmp = np.concatenate(
1262
+ [
1263
+ tmp,
1264
+ np.expand_dims(
1265
+ np.concatenate(
1266
+ [
1267
+ self.S1[k].flatten(),
1268
+ self.P00[k].flatten(),
1269
+ self.S2[k].flatten(),
1270
+ self.S2L[k].flatten(),
1271
+ ],
1272
+ 0,
1273
+ ),
1274
+ 0,
1275
+ ),
1276
+ ],
1277
+ 0,
1278
+ )
1279
+
1280
+ return np.concatenate([tmp, np.expand_dims(self.S0, 1)], 1)
849
1281
  else:
850
- tmp=self.backend.bk_expand_dims(self.backend.bk_concat([self.backend.bk_flattenR(self.S1[0]),
851
- self.backend.bk_flattenR(self.P00[0]),
852
- self.backend.bk_flattenR(self.S2[0]),
853
- self.backend.bk_flattenR(self.S2[0])],axis=0),0)
854
- for k in range(1,self.P00.shape[0]):
855
- ltmp=self.backend.bk_expand_dims(self.backend.bk_concat([self.backend.bk_flattenR(self.S1[k]),
856
- self.backend.bk_flattenR(self.P00[k]),
857
- self.backend.bk_flattenR(self.S2[k]),
858
- self.backend.bk_flattenR(self.S2[k])],axis=0),0)
859
- tmp=self.backend.bk_concat([tmp,ltmp],0)
860
-
861
- return self.backend.bk_concat([tmp,self.backend.bk_expand_dims(self.S0,1)],1)
862
-
1282
+ tmp = self.backend.bk_expand_dims(
1283
+ self.backend.bk_concat(
1284
+ [
1285
+ self.backend.bk_flattenR(self.S1[0]),
1286
+ self.backend.bk_flattenR(self.P00[0]),
1287
+ self.backend.bk_flattenR(self.S2[0]),
1288
+ self.backend.bk_flattenR(self.S2[0]),
1289
+ ],
1290
+ axis=0,
1291
+ ),
1292
+ 0,
1293
+ )
1294
+ for k in range(1, self.P00.shape[0]):
1295
+ ltmp = self.backend.bk_expand_dims(
1296
+ self.backend.bk_concat(
1297
+ [
1298
+ self.backend.bk_flattenR(self.S1[k]),
1299
+ self.backend.bk_flattenR(self.P00[k]),
1300
+ self.backend.bk_flattenR(self.S2[k]),
1301
+ self.backend.bk_flattenR(self.S2[k]),
1302
+ ],
1303
+ axis=0,
1304
+ ),
1305
+ 0,
1306
+ )
1307
+ tmp = self.backend.bk_concat([tmp, ltmp], 0)
1308
+
1309
+ return self.backend.bk_concat(
1310
+ [tmp, self.backend.bk_expand_dims(self.S0, 1)], 1
1311
+ )
1312
+
863
1313
  # ---------------------------------------------−---------
864
- def model(self,i__y,add=0,dx=3,dell=2,weigth=None,inverse=False):
1314
+ def model(self, i__y, add=0, dx=3, dell=2, weigth=None, inverse=False):
865
1315
 
866
- if i__y.shape[0]<dx+1:
867
- l__dx=i__y.shape[0]-1
1316
+ if i__y.shape[0] < dx + 1:
1317
+ l__dx = i__y.shape[0] - 1
868
1318
  else:
869
- l__dx=dx
1319
+ l__dx = dx
870
1320
 
871
- if i__y.shape[0]<dell:
872
- l__dell=0
1321
+ if i__y.shape[0] < dell:
1322
+ l__dell = 0
873
1323
  else:
874
- l__dell=dell
1324
+ l__dell = dell
875
1325
 
876
- if l__dx<2:
877
- res=np.zeros([i__y.shape[0]+add])
1326
+ if l__dx < 2:
1327
+ res = np.zeros([i__y.shape[0] + add])
878
1328
  if inverse:
879
- res[:-add]=i__y
1329
+ res[:-add] = i__y
880
1330
  else:
881
- res[add:]=i__y[0:]
1331
+ res[add:] = i__y[0:]
882
1332
  return res
883
1333
 
884
1334
  if weigth is None:
885
- w=2**(np.arange(l__dx))
1335
+ w = 2 ** (np.arange(l__dx))
886
1336
  else:
887
1337
  if not inverse:
888
- w=weigth[0:l__dx]
1338
+ w = weigth[0:l__dx]
889
1339
  else:
890
- w=weigth[-l__dx:]
1340
+ w = weigth[-l__dx:]
891
1341
 
892
- x=np.arange(l__dx)+1
1342
+ x = np.arange(l__dx) + 1
893
1343
  if not inverse:
894
- y=np.log(i__y[1:l__dx+1])
1344
+ y = np.log(i__y[1 : l__dx + 1])
895
1345
  else:
896
- y=np.log(i__y[-(l__dx+1):-1])
1346
+ y = np.log(i__y[-(l__dx + 1) : -1])
897
1347
 
898
- r=np.polyfit(x,y,1,w=w)
1348
+ r = np.polyfit(x, y, 1, w=w)
899
1349
 
900
1350
  if inverse:
901
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-1)+r[1])
902
- res[:-(l__dell+add)]=i__y[:-l__dell]
1351
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - 1) + r[1])
1352
+ res[: -(l__dell + add)] = i__y[:-l__dell]
903
1353
  else:
904
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-add)+r[1])
905
- res[l__dell+add:]=i__y[l__dell:]
1354
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - add) + r[1])
1355
+ res[l__dell + add :] = i__y[l__dell:]
906
1356
  return res
907
1357
 
908
- def findn(self,n):
909
- d=np.sqrt(1+8*n)
910
- return int((d-1)/2)
1358
+ def findn(self, n):
1359
+ d = np.sqrt(1 + 8 * n)
1360
+ return int((d - 1) / 2)
911
1361
 
912
- def findidx(self,s2):
913
- i1=np.zeros([s2.shape[1]],dtype='int')
914
- i2=np.zeros([s2.shape[1]],dtype='int')
915
- n=0
1362
+ def findidx(self, s2):
1363
+ i1 = np.zeros([s2.shape[1]], dtype="int")
1364
+ i2 = np.zeros([s2.shape[1]], dtype="int")
1365
+ n = 0
916
1366
  for k in range(self.findn(s2.shape[1])):
917
- i1[n:n+k+1]=np.arange(k+1)
918
- i2[n:n+k+1]=k
919
- n=n+k+1
920
- return i1,i2
921
-
922
- def extrapol_s2(self,add,lnorm=1):
923
- if lnorm==1:
924
- s2=self.S2.numpy()
925
- if lnorm==2:
926
- s2=self.S2L.numpy()
927
- i1,i2=self.findidx(s2)
928
-
929
- so2=np.zeros([s2.shape[0],(self.findn(s2.shape[1])+add)*(self.findn(s2.shape[1])+add+1)//2,s2.shape[2],s2.shape[3]])
930
- oi1,oi2=self.findidx(so2)
931
- for l in range(s2.shape[0]):
1367
+ i1[n : n + k + 1] = np.arange(k + 1)
1368
+ i2[n : n + k + 1] = k
1369
+ n = n + k + 1
1370
+ return i1, i2
1371
+
1372
+ def extrapol_s2(self, add, lnorm=1):
1373
+ if lnorm == 1:
1374
+ s2 = self.S2.numpy()
1375
+ if lnorm == 2:
1376
+ s2 = self.S2L.numpy()
1377
+ i1, i2 = self.findidx(s2)
1378
+
1379
+ so2 = np.zeros(
1380
+ [
1381
+ s2.shape[0],
1382
+ (self.findn(s2.shape[1]) + add)
1383
+ * (self.findn(s2.shape[1]) + add + 1)
1384
+ // 2,
1385
+ s2.shape[2],
1386
+ s2.shape[3],
1387
+ ]
1388
+ )
1389
+ oi1, oi2 = self.findidx(so2)
1390
+ for l_batch in range(s2.shape[0]):
932
1391
  for k in range(self.findn(s2.shape[1])):
933
1392
  for i in range(s2.shape[2]):
934
1393
  for j in range(s2.shape[3]):
935
- tmp=self.model(s2[l,i2==k,i,j],dx=4,dell=1,add=add,weigth=np.array([1,2,2,2]))
936
- tmp[np.isnan(tmp)]=0.0
937
- so2[l,oi2==k+add,i,j]=tmp
938
-
939
-
940
- for l in range(s2.shape[0]):
941
- for k in range(add+1,-1,-1):
942
- lidx=np.where(oi2-oi1==k)[0]
943
- lidx2=np.where(oi2-oi1==k+1)[0]
1394
+ tmp = self.model(
1395
+ s2[l_batch, i2 == k, i, j],
1396
+ dx=4,
1397
+ dell=1,
1398
+ add=add,
1399
+ weigth=np.array([1, 2, 2, 2]),
1400
+ )
1401
+ tmp[np.isnan(tmp)] = 0.0
1402
+ so2[l_batch, oi2 == k + add, i, j] = tmp
1403
+
1404
+ for l_batch in range(s2.shape[0]):
1405
+ for k in range(add + 1, -1, -1):
1406
+ lidx = np.where(oi2 - oi1 == k)[0]
1407
+ lidx2 = np.where(oi2 - oi1 == k + 1)[0]
944
1408
  for i in range(s2.shape[2]):
945
1409
  for j in range(s2.shape[3]):
946
- so2[l,lidx[0:add+2-k],i,j]=so2[l,lidx2[0:add+2-k],i,j]
1410
+ so2[l_batch, lidx[0 : add + 2 - k], i, j] = so2[
1411
+ l_batch, lidx2[0 : add + 2 - k], i, j
1412
+ ]
947
1413
 
948
- return(so2)
1414
+ return so2
949
1415
 
950
- def extrapol_s1(self,i_s1,add):
951
- s1=i_s1.numpy()
952
- so1=np.zeros([s1.shape[0],s1.shape[1]+add,s1.shape[2]])
1416
+ def extrapol_s1(self, i_s1, add):
1417
+ s1 = i_s1.numpy()
1418
+ so1 = np.zeros([s1.shape[0], s1.shape[1] + add, s1.shape[2]])
953
1419
  for k in range(s1.shape[0]):
954
1420
  for i in range(s1.shape[2]):
955
- so1[k,:,i]=self.model(s1[k,:,i],dx=4,dell=1,add=add)
956
- so1[k,np.isnan(so1[k,:,i]),i]=0.0
1421
+ so1[k, :, i] = self.model(s1[k, :, i], dx=4, dell=1, add=add)
1422
+ so1[k, np.isnan(so1[k, :, i]), i] = 0.0
957
1423
  return so1
958
1424
 
959
- def extrapol(self,add):
960
- return scat(self.extrapol_s1(self.P00,add), \
961
- self.S0, \
962
- self.extrapol_s1(self.S1,add), \
963
- self.extrapol_s2(add,lnorm=1), \
964
- self.extrapol_s2(add,lnorm=2),self.j1,self.j2,backend=self.backend)
965
-
966
-
967
-
968
-
969
-
1425
+ def extrapol(self, add):
1426
+ return scat(
1427
+ self.extrapol_s1(self.P00, add),
1428
+ self.S0,
1429
+ self.extrapol_s1(self.S1, add),
1430
+ self.extrapol_s2(add, lnorm=1),
1431
+ self.extrapol_s2(add, lnorm=2),
1432
+ self.j1,
1433
+ self.j2,
1434
+ backend=self.backend,
1435
+ )
1436
+
1437
+
970
1438
  class funct(FOC.FoCUS):
971
-
972
- def fill(self,im,nullval=hp.UNSEEN):
973
- return self.fill_healpy(im,nullval=nullval)
974
-
975
- def moments(self,list_scat):
976
- S0=None
1439
+
1440
+ def fill(self, im, nullval=hp.UNSEEN):
1441
+ return self.fill_healpy(im, nullval=nullval)
1442
+
1443
+ def moments(self, list_scat):
1444
+ S0 = None
977
1445
  for k in list_scat:
978
- tmp=list_scat[k]
979
- nS0=np.expand_dims(tmp.S0.numpy(),0)
980
- nP00=np.expand_dims(tmp.P00.numpy(),0)
981
- nS1=np.expand_dims(tmp.S1.numpy(),0)
982
- nS2=np.expand_dims(tmp.S2.numpy(),0)
983
- nS2L=np.expand_dims(tmp.S2L.numpy(),0)
984
-
1446
+ tmp = list_scat[k]
1447
+ nS0 = np.expand_dims(tmp.S0.numpy(), 0)
1448
+ nP00 = np.expand_dims(tmp.P00.numpy(), 0)
1449
+ nS1 = np.expand_dims(tmp.S1.numpy(), 0)
1450
+ nS2 = np.expand_dims(tmp.S2.numpy(), 0)
1451
+ nS2L = np.expand_dims(tmp.S2L.numpy(), 0)
1452
+
985
1453
  if S0 is None:
986
- S0=nS0
987
- P00=nP00
988
- S1=nS1
989
- S2=nS2
990
- S2L=nS2L
1454
+ S0 = nS0
1455
+ P00 = nP00
1456
+ S1 = nS1
1457
+ S2 = nS2
1458
+ S2L = nS2L
991
1459
  else:
992
- S0=np.concatenate([S0,nS0],0)
993
- P00=np.concatenate([P00,nP00],0)
994
- S1=np.concatenate([S1,nS1],0)
995
- S2=np.concatenate([S2,nS2],0)
996
- S2L=np.concatenate([S2L,nS2L],0)
997
-
998
- sS0=np.std(S0,0)
999
- sP00=np.std(P00,0)
1000
- sS1=np.std(S1,0)
1001
- sS2=np.std(S2,0)
1002
- sS2L=np.std(S2L,0)
1003
-
1004
- mS0=np.mean(S0,0)
1005
- mP00=np.mean(P00,0)
1006
- mS1=np.mean(S1,0)
1007
- mS2=np.mean(S2,0)
1008
- mS2L=np.mean(S2L,0)
1009
-
1010
- return scat(mP00,mS0,mS1,mS2,mS2L,tmp.j1,tmp.j2,backend=self.backend), \
1011
- scat(sP00,sS0,sS1,sS2,sS2L,tmp.j1,tmp.j2,backend=self.backend)
1012
-
1013
- def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,calc_var=False):
1460
+ S0 = np.concatenate([S0, nS0], 0)
1461
+ P00 = np.concatenate([P00, nP00], 0)
1462
+ S1 = np.concatenate([S1, nS1], 0)
1463
+ S2 = np.concatenate([S2, nS2], 0)
1464
+ S2L = np.concatenate([S2L, nS2L], 0)
1465
+
1466
+ sS0 = np.std(S0, 0)
1467
+ sP00 = np.std(P00, 0)
1468
+ sS1 = np.std(S1, 0)
1469
+ sS2 = np.std(S2, 0)
1470
+ sS2L = np.std(S2L, 0)
1471
+
1472
+ mS0 = np.mean(S0, 0)
1473
+ mP00 = np.mean(P00, 0)
1474
+ mS1 = np.mean(S1, 0)
1475
+ mS2 = np.mean(S2, 0)
1476
+ mS2L = np.mean(S2L, 0)
1477
+
1478
+ return scat(
1479
+ mP00, mS0, mS1, mS2, mS2L, tmp.j1, tmp.j2, backend=self.backend
1480
+ ), scat(sP00, sS0, sS1, sS2, sS2L, tmp.j1, tmp.j2, backend=self.backend)
1481
+
1482
+ def eval(
1483
+ self,
1484
+ image1,
1485
+ image2=None,
1486
+ mask=None,
1487
+ Auto=True,
1488
+ s0_off=1e-6,
1489
+ calc_var=False,
1490
+ norm=None,
1491
+ ):
1014
1492
  # Check input consistency
1015
1493
  if image2 is not None:
1016
- if list(image1.shape)!=list(image2.shape):
1017
- print('The two input image should have the same size to eval Scattering')
1018
-
1019
- exit(0)
1494
+ if list(image1.shape) != list(image2.shape):
1495
+ print(
1496
+ "The two input image should have the same size to eval Scattering"
1497
+ )
1498
+
1499
+ return None
1020
1500
  if mask is not None:
1021
- if list(image1.shape)!=list(mask.shape)[1:]:
1022
- print('The mask should have the same size than the input image to eval Scattering')
1023
- print(image1.shape,mask.shape)
1024
- exit(0)
1025
- if self.use_2D and len(image1.shape)<2:
1026
- print('To work with 2D scattering transform, two dimension is needed, input map has only on dimension')
1027
- exit(0)
1028
-
1029
-
1501
+ if list(image1.shape) != list(mask.shape)[1:]:
1502
+ print(
1503
+ "The mask should have the same size than the input image to eval Scattering"
1504
+ )
1505
+ print("Image shape ", image1.shape, "Mask shape ", mask.shape)
1506
+ return None
1507
+ if self.use_2D and len(image1.shape) < 2:
1508
+ print(
1509
+ "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
1510
+ )
1511
+ return None
1512
+
1030
1513
  ### AUTO OR CROSS
1031
1514
  cross = False
1032
1515
  if image2 is not None:
1033
1516
  cross = True
1034
- all_cross=not Auto
1035
- else:
1036
- all_cross=False
1037
-
1517
+
1038
1518
  # Check if image1 is [Npix] or [Nbatch,Npix]
1039
- axis=1
1040
-
1519
+ axis = 1
1520
+
1041
1521
  # determine jmax and nside corresponding to the input map
1042
1522
  im_shape = image1.shape
1043
1523
  if self.use_2D:
1044
- if len(image1.shape)==2:
1045
- nside=np.min([im_shape[0],im_shape[1]])
1046
- npix = im_shape[0]*im_shape[1] # Number of pixels
1047
- x1=im_shape[0]
1048
- x2=im_shape[1]
1524
+ if len(image1.shape) == 2:
1525
+ nside = np.min([im_shape[0], im_shape[1]])
1526
+ npix = im_shape[0] * im_shape[1] # Number of pixels
1049
1527
  else:
1050
- nside=np.min([im_shape[1],im_shape[2]])
1051
- npix = im_shape[1]*im_shape[2] # Number of pixels
1052
- x1=im_shape[1]
1053
- x2=im_shape[2]
1054
- jmax = int(np.log(nside-self.KERNELSZ) / np.log(2)) # Number of j scales
1528
+ nside = np.min([im_shape[1], im_shape[2]])
1529
+ npix = im_shape[1] * im_shape[2] # Number of pixels
1530
+ jmax = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
1055
1531
  else:
1056
- if len(image1.shape)==2:
1532
+ if len(image1.shape) == 2:
1057
1533
  npix = int(im_shape[1]) # Number of pixels
1058
1534
  else:
1059
1535
  npix = int(im_shape[0]) # Number of pixels
1060
1536
 
1061
- nside=int(np.sqrt(npix//12))
1062
-
1063
- jmax=int(np.log(nside)/np.log(2)) #-self.OSTEP
1537
+ nside = int(np.sqrt(npix // 12))
1538
+
1539
+ jmax = int(np.log(nside) / np.log(2)) # -self.OSTEP
1064
1540
 
1065
1541
  ### LOCAL VARIABLES (IMAGES and MASK)
1066
1542
  # Check if image1 is [Npix] or [Nbatch,Npix]
1067
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1543
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1068
1544
  # image1 is [Nbatch, Npix]
1069
- I1 = self.backend.bk_cast(self.backend.bk_expand_dims(image1,0)) # Local image1 [Nbatch, Npix]
1545
+ I1 = self.backend.bk_cast(
1546
+ self.backend.bk_expand_dims(image1, 0)
1547
+ ) # Local image1 [Nbatch, Npix]
1070
1548
  if cross:
1071
- I2 = self.backend.bk_cast(self.backend.bk_expand_dims(image2,0)) # Local image2 [Nbatch, Npix]
1549
+ I2 = self.backend.bk_cast(
1550
+ self.backend.bk_expand_dims(image2, 0)
1551
+ ) # Local image2 [Nbatch, Npix]
1072
1552
  else:
1073
- I1=self.backend.bk_cast(image1)
1553
+ I1 = self.backend.bk_cast(image1)
1074
1554
  if cross:
1075
- I2=self.backend.bk_cast(image2)
1076
-
1555
+ I2 = self.backend.bk_cast(image2)
1556
+
1077
1557
  # self.mask is [Nmask, Npix]
1078
1558
  if mask is None:
1079
1559
  if self.use_2D:
1080
- vmask = self.backend.bk_ones([1, I1.shape[axis], I1.shape[axis+1]],dtype=self.all_type)
1560
+ vmask = self.backend.bk_ones(
1561
+ [1, I1.shape[axis], I1.shape[axis + 1]], dtype=self.all_type
1562
+ )
1081
1563
  else:
1082
1564
  vmask = self.backend.bk_ones([1, I1.shape[axis]], dtype=self.all_type)
1083
1565
  else:
1084
1566
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
1085
1567
 
1086
- if self.KERNELSZ>3:
1087
- if self.KERNELSZ==5:
1568
+ if self.KERNELSZ > 3:
1569
+ if self.KERNELSZ == 5:
1088
1570
  # if the kernel size is bigger than 3 increase the binning before smoothing
1089
1571
  if self.use_2D:
1090
- print(axis,image1.shape)
1091
- l_image1=self.up_grade(I1,I1.shape[axis]*2,axis=axis,nouty=I1.shape[axis+1]*2)
1092
- vmask=self.up_grade(vmask,I1.shape[axis]*2,axis=1,nouty=I1.shape[axis+1]*2)
1572
+ l_image1 = self.up_grade(
1573
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
1574
+ )
1575
+ vmask = self.up_grade(
1576
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
1577
+ )
1093
1578
  else:
1094
- l_image1=self.up_grade(I1,nside*2,axis=axis)
1095
- vmask=self.up_grade(vmask,nside*2,axis=1)
1096
-
1579
+ l_image1 = self.up_grade(I1, nside * 2, axis=axis)
1580
+ vmask = self.up_grade(vmask, nside * 2, axis=1)
1581
+
1097
1582
  if cross:
1098
1583
  if self.use_2D:
1099
- l_image2=self.up_grade(I2,I2.shape[axis]*2,axis=axis,nouty=I2.shape[axis+1]*2)
1584
+ l_image2 = self.up_grade(
1585
+ I2,
1586
+ I2.shape[axis] * 2,
1587
+ axis=axis,
1588
+ nouty=I2.shape[axis + 1] * 2,
1589
+ )
1100
1590
  else:
1101
- l_image2=self.up_grade(I2,nside*2,axis=axis)
1591
+ l_image2 = self.up_grade(I2, nside * 2, axis=axis)
1102
1592
  else:
1103
1593
  # if the kernel size is bigger than 3 increase the binning before smoothing
1104
1594
  if self.use_2D:
1105
- print(axis,image1.shape)
1106
- l_image1=self.up_grade(l_image1,I1.shape[axis]*4,axis=axis,nouty=I1.shape[axis+1]*4)
1107
- vmask=self.up_grade(vmask,I1.shape[axis]*4,axis=1,nouty=I1.shape[axis+1]*4)
1595
+ l_image1 = self.up_grade(
1596
+ l_image1,
1597
+ I1.shape[axis] * 4,
1598
+ axis=axis,
1599
+ nouty=I1.shape[axis + 1] * 4,
1600
+ )
1601
+ vmask = self.up_grade(
1602
+ vmask, I1.shape[axis] * 4, axis=1, nouty=I1.shape[axis + 1] * 4
1603
+ )
1108
1604
  else:
1109
- l_image1=self.up_grade(l_image1,nside*4,axis=axis)
1110
- vmask=self.up_grade(vmask,nside*4,axis=1)
1111
-
1605
+ l_image1 = self.up_grade(l_image1, nside * 4, axis=axis)
1606
+ vmask = self.up_grade(vmask, nside * 4, axis=1)
1607
+
1112
1608
  if cross:
1113
1609
  if self.use_2D:
1114
- l_image2=self.up_grade(l_image2,I2.shape[axis]*4,axis=axis,nouty=I2.shape[axis+1]*4)
1610
+ l_image2 = self.up_grade(
1611
+ l_image2,
1612
+ I2.shape[axis] * 4,
1613
+ axis=axis,
1614
+ nouty=I2.shape[axis + 1] * 4,
1615
+ )
1115
1616
  else:
1116
- l_image2=self.up_grade(l_image2,nside*4,axis=axis)
1617
+ l_image2 = self.up_grade(l_image2, nside * 4, axis=axis)
1117
1618
  else:
1118
- l_image1=I1
1619
+ l_image1 = I1
1119
1620
  if cross:
1120
- l_image2=I2
1621
+ l_image2 = I2
1121
1622
 
1122
1623
  if calc_var:
1123
- s0,vs0 = self.masked_mean(l_image1,vmask,axis=axis,calc_var=True)
1124
- s0=s0+s0_off
1624
+ s0, vs0 = self.masked_mean(l_image1, vmask, axis=axis, calc_var=True)
1625
+ s0 = s0 + s0_off
1125
1626
  else:
1126
- s0 = self.masked_mean(l_image1,vmask,axis=axis)+s0_off
1127
-
1128
- if cross and Auto==False:
1627
+ s0 = self.masked_mean(l_image1, vmask, axis=axis) + s0_off
1628
+
1629
+ if cross and not Auto:
1129
1630
  if calc_var:
1130
- s02,vs02=self.masked_mean(l_image2,vmask,axis=axis,calc_var=True)
1631
+ s02, vs02 = self.masked_mean(l_image2, vmask, axis=axis, calc_var=True)
1131
1632
  else:
1132
- s02=self.masked_mean(l_image2,vmask,axis=axis)
1133
-
1134
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1135
- if s0.dtype!='complex64' and s0.dtype!='complex128':
1136
- s0 = self.backend.bk_complex(s0,s02+s0_off)
1633
+ s02 = self.masked_mean(l_image2, vmask, axis=axis)
1634
+
1635
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1636
+ if self.backend.bk_is_complex(s0):
1637
+ s0 = self.backend.bk_complex(s0, s02 + s0_off)
1137
1638
  if calc_var:
1138
- vs0 = self.backend.bk_complex(vs0,vs02)
1639
+ vs0 = self.backend.bk_complex(vs0, vs02)
1139
1640
  else:
1140
- s0 = self.backend.bk_concat([s0,s02],axis=0)
1641
+ s0 = self.backend.bk_concat([s0, s02], axis=0)
1141
1642
  if calc_var:
1142
- vs0 = self.backend.bk_concat([vs0,vs02],axis=0)
1643
+ vs0 = self.backend.bk_concat([vs0, vs02], axis=0)
1143
1644
  else:
1144
- if s0.dtype!='complex64' and s0.dtype!='complex128':
1145
- s0 = self.backend.bk_complex(s0,s02+s0_off)
1645
+ if self.backend.bk_is_complex(s0):
1646
+ s0 = self.backend.bk_complex(s0, s02 + s0_off)
1146
1647
  if calc_var:
1147
- vs0 = self.backend.bk_complex(vs0,vs02)
1648
+ vs0 = self.backend.bk_complex(vs0, vs02)
1148
1649
  else:
1149
- s0 = self.backend.bk_concat([s0,s02],axis=0)
1650
+ s0 = self.backend.bk_concat([s0, s02], axis=0)
1150
1651
  if calc_var:
1151
- vs0 = self.backend.bk_concat([vs0,vs02],axis=0)
1152
-
1153
- s1=None
1154
- s2=None
1155
- s2l=None
1156
- p00=None
1157
- s2j1=None
1158
- s2j2=None
1159
-
1160
- l2_image=None
1161
- l2_image_imag=None
1162
-
1652
+ vs0 = self.backend.bk_concat([vs0, vs02], axis=0)
1653
+
1654
+ s1 = None
1655
+ s2 = None
1656
+ s2l = None
1657
+ p00 = None
1658
+ s2j1 = None
1659
+ s2j2 = None
1660
+ l2_image = None
1163
1661
  for j1 in range(jmax):
1164
- if j1<jmax-self.OSTEP: # stop to add scales
1662
+ if j1 < jmax - self.OSTEP: # stop to add scales
1165
1663
  # Convol image along the axis defined by 'axis' using the wavelet defined at
1166
1664
  # the foscat initialisation
1167
- #c_image_real is [....,Npix_j1,....,Norient]
1168
- c_image1=self.convol(l_image1,axis=axis)
1665
+ # c_image_real is [....,Npix_j1,....,Norient]
1666
+ c_image1 = self.convol(l_image1, axis=axis)
1169
1667
  if cross:
1170
- c_image2=self.convol(l_image2,axis=axis)
1668
+ c_image2 = self.convol(l_image2, axis=axis)
1171
1669
  else:
1172
- c_image2=c_image1
1670
+ c_image2 = c_image1
1173
1671
 
1174
1672
  # Compute (a+ib)*(a+ib)* the last c_image column is the real and imaginary part
1175
- conj=c_image1*self.backend.bk_conjugate(c_image2)
1176
-
1673
+ conj = c_image1 * self.backend.bk_conjugate(c_image2)
1674
+
1177
1675
  if Auto:
1178
- conj=self.backend.bk_real(conj)
1676
+ conj = self.backend.bk_real(conj)
1179
1677
 
1180
1678
  # Compute l_p00 [....,....,Nmask,j1,Norient]
1181
1679
  if calc_var:
1182
- l_p00,l_vp00 = self.masked_mean(conj,vmask,axis=axis,rank=j1,calc_var=True)
1183
- l_p00 = self.backend.bk_expand_dims(l_p00,-2)
1184
- l_vp00 = self.backend.bk_expand_dims(l_vp00,-2)
1680
+ l_p00, l_vp00 = self.masked_mean(
1681
+ conj, vmask, axis=axis, rank=j1, calc_var=True
1682
+ )
1683
+ l_p00 = self.backend.bk_expand_dims(l_p00, -2)
1684
+ l_vp00 = self.backend.bk_expand_dims(l_vp00, -2)
1185
1685
  else:
1186
- l_p00 = self.masked_mean(conj,vmask,axis=axis,rank=j1)
1187
- l_p00 = self.backend.bk_expand_dims(l_p00,-2)
1686
+ l_p00 = self.masked_mean(conj, vmask, axis=axis, rank=j1)
1687
+ l_p00 = self.backend.bk_expand_dims(l_p00, -2)
1188
1688
 
1189
- conj=self.backend.bk_L1(conj)
1689
+ conj = self.backend.bk_L1(conj)
1190
1690
 
1191
- # Compute l_s1 [....,....,Nmask,1,Norient]
1691
+ # Compute l_s1 [....,....,Nmask,1,Norient]
1192
1692
  if calc_var:
1193
- l_s1,l_vs1 = self.masked_mean(conj,vmask,axis=axis,rank=j1,calc_var=True)
1194
- l_s1 =self.backend.bk_expand_dims(l_s1,-2)
1195
- l_vs1 =self.backend.bk_expand_dims(l_vs1,-2)
1693
+ l_s1, l_vs1 = self.masked_mean(
1694
+ conj, vmask, axis=axis, rank=j1, calc_var=True
1695
+ )
1696
+ l_s1 = self.backend.bk_expand_dims(l_s1, -2)
1697
+ l_vs1 = self.backend.bk_expand_dims(l_vs1, -2)
1196
1698
  else:
1197
- l_s1 = self.backend.bk_expand_dims(self.masked_mean(conj,vmask,axis=axis,rank=j1),-2)
1699
+ l_s1 = self.backend.bk_expand_dims(
1700
+ self.masked_mean(conj, vmask, axis=axis, rank=j1), -2
1701
+ )
1198
1702
 
1199
- # Concat S1,P00 [....,....,Nmask,j1,Norient]
1703
+ # Concat S1,P00 [....,....,Nmask,j1,Norient]
1200
1704
  if s1 is None:
1201
- s1=l_s1
1202
- p00=l_p00
1705
+ s1 = l_s1
1706
+ p00 = l_p00
1203
1707
  if calc_var:
1204
- vs1=l_vs1
1205
- vp00=l_vp00
1708
+ vs1 = l_vs1
1709
+ vp00 = l_vp00
1206
1710
  else:
1207
- s1=self.backend.bk_concat([s1,l_s1],axis=-2)
1208
- p00=self.backend.bk_concat([p00,l_p00],axis=-2)
1711
+ s1 = self.backend.bk_concat([s1, l_s1], axis=-2)
1712
+ p00 = self.backend.bk_concat([p00, l_p00], axis=-2)
1209
1713
  if calc_var:
1210
- vs1=self.backend.bk_concat([vs1,l_vs1],axis=-2)
1211
- vp00=self.backend.bk_concat([vp00,l_vp00],axis=-2)
1714
+ vs1 = self.backend.bk_concat([vs1, l_vs1], axis=-2)
1715
+ vp00 = self.backend.bk_concat([vp00, l_vp00], axis=-2)
1212
1716
 
1213
1717
  # Concat l2_image [....,j1,Npix_j1,,....,Norient]
1214
1718
  if l2_image is None:
1215
1719
  if self.use_2D:
1216
- l2_image=self.backend.bk_expand_dims(conj,axis=-4)
1720
+ l2_image = self.backend.bk_expand_dims(conj, axis=-4)
1217
1721
  else:
1218
- l2_image=self.backend.bk_expand_dims(conj,axis=-3)
1722
+ l2_image = self.backend.bk_expand_dims(conj, axis=-3)
1219
1723
  else:
1220
1724
  if self.use_2D:
1221
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=-4),l2_image],axis=-4)
1725
+ l2_image = self.backend.bk_concat(
1726
+ [self.backend.bk_expand_dims(conj, axis=-4), l2_image],
1727
+ axis=-4,
1728
+ )
1222
1729
  else:
1223
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=-3),l2_image],axis=-3)
1730
+ l2_image = self.backend.bk_concat(
1731
+ [self.backend.bk_expand_dims(conj, axis=-3), l2_image],
1732
+ axis=-3,
1733
+ )
1224
1734
 
1225
1735
  # Convol l2_image [....,Npix_j1,j1,....,Norient,Norient]
1226
- c2_image=self.convol(self.backend.bk_relu(l2_image),axis=axis+1)
1736
+ c2_image = self.convol(self.backend.bk_relu(l2_image), axis=axis + 1)
1227
1737
 
1228
- conj2p=c2_image*self.backend.bk_conjugate(c2_image)
1229
- conj2pl1=self.backend.bk_L1(conj2p)
1738
+ conj2p = c2_image * self.backend.bk_conjugate(c2_image)
1739
+ conj2pl1 = self.backend.bk_L1(conj2p)
1230
1740
 
1231
1741
  if Auto:
1232
- conj2p=self.backend.bk_real(conj2p)
1233
- conj2pl1=self.backend.bk_real(conj2pl1)
1742
+ conj2p = self.backend.bk_real(conj2p)
1743
+ conj2pl1 = self.backend.bk_real(conj2pl1)
1234
1744
 
1235
- c2_image=self.convol(self.backend.bk_relu(-l2_image),axis=axis+1)
1745
+ c2_image = self.convol(self.backend.bk_relu(-l2_image), axis=axis + 1)
1236
1746
 
1237
- conj2m=c2_image*self.backend.bk_conjugate(c2_image)
1238
- conj2ml1=self.backend.bk_L1(conj2m)
1747
+ conj2m = c2_image * self.backend.bk_conjugate(c2_image)
1748
+ conj2ml1 = self.backend.bk_L1(conj2m)
1239
1749
 
1240
1750
  if Auto:
1241
- conj2m=self.backend.bk_real(conj2m)
1242
- conj2ml1=self.backend.bk_real(conj2ml1)
1243
-
1751
+ conj2m = self.backend.bk_real(conj2m)
1752
+ conj2ml1 = self.backend.bk_real(conj2ml1)
1753
+
1244
1754
  # Convol l_s2 [....,....,Nmask,j1,Norient,Norient]
1245
1755
  if calc_var:
1246
- l_s2,l_vs2 = self.masked_mean(conj2p-conj2m,vmask,axis=axis+1,rank=j1,calc_var=True)
1247
- l_s2l1,l_vs2l1 = self.masked_mean(conj2pl1-conj2ml1,vmask,axis=axis+1,rank=j1,calc_var=True)
1756
+ l_s2, l_vs2 = self.masked_mean(
1757
+ conj2p - conj2m, vmask, axis=axis + 1, rank=j1, calc_var=True
1758
+ )
1759
+ l_s2l1, l_vs2l1 = self.masked_mean(
1760
+ conj2pl1 - conj2ml1, vmask, axis=axis + 1, rank=j1, calc_var=True
1761
+ )
1248
1762
  else:
1249
- l_s2 = self.masked_mean(conj2p-conj2m,vmask,axis=axis+1,rank=j1)
1250
- l_s2l1 = self.masked_mean(conj2pl1-conj2ml1,vmask,axis=axis+1,rank=j1)
1763
+ l_s2 = self.masked_mean(conj2p - conj2m, vmask, axis=axis + 1, rank=j1)
1764
+ l_s2l1 = self.masked_mean(
1765
+ conj2pl1 - conj2ml1, vmask, axis=axis + 1, rank=j1
1766
+ )
1251
1767
 
1252
1768
  # Concat l_s2 [....,....,Nmask,j1*(j1+1)/2,Norient,Norient]
1253
1769
  if s2 is None:
1254
- s2l=l_s2
1255
- s2=l_s2l1
1770
+ s2l = l_s2
1771
+ s2 = l_s2l1
1256
1772
  if calc_var:
1257
- vs2l=l_vs2
1258
- vs2=l_vs2l1
1259
-
1260
- s2j1=np.arange(l_s2.shape[axis+1],dtype='int')
1261
- s2j2=j1*np.ones(l_s2.shape[axis+1],dtype='int')
1773
+ vs2l = l_vs2
1774
+ vs2 = l_vs2l1
1775
+
1776
+ s2j1 = np.arange(l_s2.shape[axis + 1], dtype="int")
1777
+ s2j2 = j1 * np.ones(l_s2.shape[axis + 1], dtype="int")
1262
1778
  else:
1263
- s2=self.backend.bk_concat([s2,l_s2l1],axis=-3)
1264
- s2l=self.backend.bk_concat([s2l,l_s2],axis=-3)
1779
+ s2 = self.backend.bk_concat([s2, l_s2l1], axis=-3)
1780
+ s2l = self.backend.bk_concat([s2l, l_s2], axis=-3)
1265
1781
  if calc_var:
1266
- vs2=self.backend.bk_concat([vs2,l_vs2l1],axis=-3)
1267
- vs2l=self.backend.bk_concat([vs2l,l_vs2],axis=-3)
1268
-
1269
- s2j1=np.concatenate([s2j1,np.arange(l_s2.shape[axis+1],dtype='int')],0)
1270
- s2j2=np.concatenate([s2j2,j1*np.ones(l_s2.shape[axis+1],dtype='int')],0)
1271
-
1272
- if j1!=jmax-1:
1273
- # Rescale vmask [Nmask,Npix_j1//4]
1274
- vmask = self.smooth(vmask,axis=1)
1275
- vmask = self.ud_grade_2(vmask,axis=1)
1782
+ vs2 = self.backend.bk_concat([vs2, l_vs2l1], axis=-3)
1783
+ vs2l = self.backend.bk_concat([vs2l, l_vs2], axis=-3)
1784
+
1785
+ s2j1 = np.concatenate(
1786
+ [s2j1, np.arange(l_s2.shape[axis + 1], dtype="int")], 0
1787
+ )
1788
+ s2j2 = np.concatenate(
1789
+ [s2j2, j1 * np.ones(l_s2.shape[axis + 1], dtype="int")], 0
1790
+ )
1791
+
1792
+ if j1 != jmax - 1:
1793
+ # Rescale vmask [Nmask,Npix_j1//4]
1794
+ vmask = self.smooth(vmask, axis=1)
1795
+ vmask = self.ud_grade_2(vmask, axis=1)
1276
1796
  if self.mask_thres is not None:
1277
- vmask = self.backend.bk_threshold(vmask,self.mask_thres)
1797
+ vmask = self.backend.bk_threshold(vmask, self.mask_thres)
1278
1798
 
1279
- # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1280
- l2_image = self.smooth(l2_image,axis=axis+1)
1281
- l2_image = self.ud_grade_2(l2_image,axis=axis+1)
1799
+ # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1800
+ l2_image = self.smooth(l2_image, axis=axis + 1)
1801
+ l2_image = self.ud_grade_2(l2_image, axis=axis + 1)
1282
1802
 
1283
- # Rescale l_image [....,Npix_j1//4,....]
1284
- l_image1 = self.smooth(l_image1,axis=axis)
1285
- l_image1 = self.ud_grade_2(l_image1,axis=axis)
1803
+ # Rescale l_image [....,Npix_j1//4,....]
1804
+ l_image1 = self.smooth(l_image1, axis=axis)
1805
+ l_image1 = self.ud_grade_2(l_image1, axis=axis)
1286
1806
  if cross:
1287
- l_image2 = self.smooth(l_image2,axis=axis)
1288
- l_image2 = self.ud_grade_2(l_image2,axis=axis)
1289
-
1290
-
1291
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1292
- sc_ret=scat(p00[0],s0[0],s1[0],s2[0],s2l[0],s2j1,s2j2,cross=cross,backend=self.backend)
1807
+ l_image2 = self.smooth(l_image2, axis=axis)
1808
+ l_image2 = self.ud_grade_2(l_image2, axis=axis)
1809
+
1810
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1811
+ sc_ret = scat(
1812
+ p00[0],
1813
+ s0[0],
1814
+ s1[0],
1815
+ s2[0],
1816
+ s2l[0],
1817
+ s2j1,
1818
+ s2j2,
1819
+ cross=cross,
1820
+ backend=self.backend,
1821
+ )
1293
1822
  else:
1294
- sc_ret=scat(p00,s0,s1,s2,s2l,s2j1,s2j2,cross=cross,backend=self.backend)
1295
-
1823
+ sc_ret = scat(
1824
+ p00, s0, s1, s2, s2l, s2j1, s2j2, cross=cross, backend=self.backend
1825
+ )
1826
+
1296
1827
  if calc_var:
1297
- if len(image1.shape)==1 or (len(image1.shape)==2 and self.use_2D):
1298
- vsc_ret=scat(vp00[0],vs0[0],vs1[0],vs2[0],vs2l[0],s2j1,s2j2,cross=cross,backend=self.backend)
1828
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
1829
+ vsc_ret = scat(
1830
+ vp00[0],
1831
+ vs0[0],
1832
+ vs1[0],
1833
+ vs2[0],
1834
+ vs2l[0],
1835
+ s2j1,
1836
+ s2j2,
1837
+ cross=cross,
1838
+ backend=self.backend,
1839
+ )
1299
1840
  else:
1300
- vsc_ret=scat(vp00,vs0,vs1,vs2,vs2l,s2j1,s2j2,cross=cross,backend=self.backend)
1301
- return sc_ret,vsc_ret
1841
+ vsc_ret = scat(
1842
+ vp00,
1843
+ vs0,
1844
+ vs1,
1845
+ vs2,
1846
+ vs2l,
1847
+ s2j1,
1848
+ s2j2,
1849
+ cross=cross,
1850
+ backend=self.backend,
1851
+ )
1852
+ return sc_ret, vsc_ret
1302
1853
  else:
1303
1854
  return sc_ret
1304
1855
 
1305
- def square(self,x):
1856
+ def square(self, x):
1306
1857
  # the abs make the complex value usable for reduce_sum or mean
1307
- return scat(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1308
- self.backend.bk_square(self.backend.bk_abs(x.S0)),
1309
- self.backend.bk_square(self.backend.bk_abs(x.S1)),
1310
- self.backend.bk_square(self.backend.bk_abs(x.S2)),
1311
- self.backend.bk_square(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1312
-
1313
- def sqrt(self,x):
1858
+ return scat(
1859
+ self.backend.bk_square(self.backend.bk_abs(x.P00)),
1860
+ self.backend.bk_square(self.backend.bk_abs(x.S0)),
1861
+ self.backend.bk_square(self.backend.bk_abs(x.S1)),
1862
+ self.backend.bk_square(self.backend.bk_abs(x.S2)),
1863
+ self.backend.bk_square(self.backend.bk_abs(x.S2L)),
1864
+ x.j1,
1865
+ x.j2,
1866
+ backend=self.backend,
1867
+ )
1868
+
1869
+ def sqrt(self, x):
1314
1870
  # the abs make the complex value usable for reduce_sum or mean
1315
- return scat(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1316
- self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1317
- self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1318
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1319
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1871
+ return scat(
1872
+ self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1873
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1874
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1875
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1876
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),
1877
+ x.j1,
1878
+ x.j2,
1879
+ backend=self.backend,
1880
+ )
1881
+
1882
+ def reduce_distance(self, x, y, sigma=None):
1883
+
1884
+ if isinstance(x, scat):
1885
+ if sigma is None:
1886
+ result = self.diff_data(y.S0, x.S0, is_complex=False)
1887
+ result += self.diff_data(y.S1, x.S1)
1888
+ result += self.diff_data(y.P00, x.P00)
1889
+ result += self.diff_data(y.S2, x.S2)
1890
+ result += self.diff_data(y.S2L, x.S2L)
1891
+ else:
1892
+ result = self.diff_data(y.S0, x.S0, is_complex=False, sigma=sigma.S0)
1893
+ result += self.diff_data(y.S1, x.S1, sigma=sigma.S1)
1894
+ result += self.diff_data(y.P00, x.P00, sigma=sigma.P00)
1895
+ result += self.diff_data(y.S2, x.S2, sigma=sigma.S2)
1896
+ result += self.diff_data(y.S2L, x.S2L, sigma=sigma.S2L)
1897
+
1898
+ nval = (
1899
+ self.backend.bk_size(x.S0)
1900
+ + self.backend.bk_size(x.P00)
1901
+ + self.backend.bk_size(x.S1)
1902
+ + self.backend.bk_size(x.S2)
1903
+ + self.backend.bk_size(x.S2L)
1904
+ )
1905
+
1906
+ result /= self.backend.bk_cast(nval)
1907
+ else:
1908
+ return self.backend.bk_reduce_sum(x)
1909
+ return result
1320
1910
 
1321
- def reduce_mean(self,x,axis=None):
1911
+ def reduce_mean(self, x, axis=None):
1322
1912
  if axis is None:
1323
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))+ \
1324
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))+ \
1325
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))+ \
1326
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))+ \
1327
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1328
-
1329
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1330
- np.array(list(x.S0.shape)).prod()+ \
1331
- np.array(list(x.S1.shape)).prod()+ \
1332
- np.array(list(x.S2.shape)).prod()
1333
-
1334
- return tmp/ntmp
1913
+ tmp = (
1914
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))
1915
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))
1916
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))
1917
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))
1918
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1919
+ )
1920
+
1921
+ ntmp = (
1922
+ np.array(list(x.P00.shape)).prod()
1923
+ + np.array(list(x.S0.shape)).prod()
1924
+ + np.array(list(x.S1.shape)).prod()
1925
+ + np.array(list(x.S2.shape)).prod()
1926
+ )
1927
+
1928
+ return tmp / ntmp
1335
1929
  else:
1336
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00,axis=axis))+ \
1337
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0,axis=axis))+ \
1338
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1,axis=axis))+ \
1339
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2,axis=axis))+ \
1340
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L,axis=axis))
1341
-
1342
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1343
- np.array(list(x.S0.shape)).prod()+ \
1344
- np.array(list(x.S1.shape)).prod()+ \
1345
- np.array(list(x.S2.shape)).prod()+ \
1346
- np.array(list(x.S2L.shape)).prod()
1347
-
1348
- return tmp/ntmp
1349
-
1350
- def reduce_sum(self,x,axis=None):
1930
+ tmp = (
1931
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00, axis=axis))
1932
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0, axis=axis))
1933
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1, axis=axis))
1934
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2, axis=axis))
1935
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L, axis=axis))
1936
+ )
1937
+
1938
+ ntmp = (
1939
+ np.array(list(x.P00.shape)).prod()
1940
+ + np.array(list(x.S0.shape)).prod()
1941
+ + np.array(list(x.S1.shape)).prod()
1942
+ + np.array(list(x.S2.shape)).prod()
1943
+ + np.array(list(x.S2L.shape)).prod()
1944
+ )
1945
+
1946
+ return tmp / ntmp
1947
+
1948
+ def reduce_sum(self, x, axis=None):
1351
1949
  if axis is None:
1352
- return self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))+ \
1353
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))+ \
1354
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))+ \
1355
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))+ \
1356
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1950
+ return (
1951
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))
1952
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
1953
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
1954
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
1955
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1956
+ )
1357
1957
  else:
1358
- return scat(self.backend.bk_reduce_sum(x.P00,axis=axis),
1359
- self.backend.bk_reduce_sum(x.S0,axis=axis),
1360
- self.backend.bk_reduce_sum(x.S1,axis=axis),
1361
- self.backend.bk_reduce_sum(x.S2,axis=axis),
1362
- self.backend.bk_reduce_sum(x.S2L,axis=axis),x.j1,x.j2,backend=self.backend)
1363
-
1364
- def ldiff(self,sig,x):
1365
- return scat(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1366
- x.domult(sig.S0,x.S0)*x.domult(sig.S0,x.S0),
1367
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1368
- x.domult(sig.S2,x.S2)*x.domult(sig.S2,x.S2),
1369
- x.domult(sig.S2L,x.S2L)*x.domult(sig.S2L,x.S2L),x.j1,x.j2,backend=self.backend)
1370
-
1371
- def log(self,x):
1372
- return scat(self.backend.bk_log(x.P00),
1373
- self.backend.bk_log(x.S0),
1374
- self.backend.bk_log(x.S1),
1375
- self.backend.bk_log(x.S2),
1376
- self.backend.bk_log(x.S2L),x.j1,x.j2,backend=self.backend)
1377
- def abs(self,x):
1378
- return scat(self.backend.bk_abs(x.P00),
1379
- self.backend.bk_abs(x.S0),
1380
- self.backend.bk_abs(x.S1),
1381
- self.backend.bk_abs(x.S2),
1382
- self.backend.bk_abs(x.S2L),x.j1,x.j2,backend=self.backend)
1383
- def inv(self,x):
1384
- return scat(1/(x.P00),1/(x.S0),1/(x.S1),1/(x.S2),1/(x.S2L),x.j1,x.j2,backend=self.backend)
1958
+ return scat(
1959
+ self.backend.bk_reduce_sum(x.P00, axis=axis),
1960
+ self.backend.bk_reduce_sum(x.S0, axis=axis),
1961
+ self.backend.bk_reduce_sum(x.S1, axis=axis),
1962
+ self.backend.bk_reduce_sum(x.S2, axis=axis),
1963
+ self.backend.bk_reduce_sum(x.S2L, axis=axis),
1964
+ x.j1,
1965
+ x.j2,
1966
+ backend=self.backend,
1967
+ )
1968
+
1969
+ def ldiff(self, sig, x):
1970
+ return scat(
1971
+ x.domult(sig.P00, x.P00) * x.domult(sig.P00, x.P00),
1972
+ x.domult(sig.S0, x.S0) * x.domult(sig.S0, x.S0),
1973
+ x.domult(sig.S1, x.S1) * x.domult(sig.S1, x.S1),
1974
+ x.domult(sig.S2, x.S2) * x.domult(sig.S2, x.S2),
1975
+ x.domult(sig.S2L, x.S2L) * x.domult(sig.S2L, x.S2L),
1976
+ x.j1,
1977
+ x.j2,
1978
+ backend=self.backend,
1979
+ )
1980
+
1981
+ def log(self, x):
1982
+ return scat(
1983
+ self.backend.bk_log(x.P00),
1984
+ self.backend.bk_log(x.S0),
1985
+ self.backend.bk_log(x.S1),
1986
+ self.backend.bk_log(x.S2),
1987
+ self.backend.bk_log(x.S2L),
1988
+ x.j1,
1989
+ x.j2,
1990
+ backend=self.backend,
1991
+ )
1992
+
1993
+ def abs(self, x):
1994
+ return scat(
1995
+ self.backend.bk_abs(x.P00),
1996
+ self.backend.bk_abs(x.S0),
1997
+ self.backend.bk_abs(x.S1),
1998
+ self.backend.bk_abs(x.S2),
1999
+ self.backend.bk_abs(x.S2L),
2000
+ x.j1,
2001
+ x.j2,
2002
+ backend=self.backend,
2003
+ )
2004
+
2005
+ def inv(self, x):
2006
+ return scat(
2007
+ 1 / (x.P00),
2008
+ 1 / (x.S0),
2009
+ 1 / (x.S1),
2010
+ 1 / (x.S2),
2011
+ 1 / (x.S2L),
2012
+ x.j1,
2013
+ x.j2,
2014
+ backend=self.backend,
2015
+ )
1385
2016
 
1386
2017
  def one(self):
1387
- return scat(1.0,1.0,1.0,1.0,1.0,[0],[0],backend=self.backend)
2018
+ return scat(1.0, 1.0, 1.0, 1.0, 1.0, [0], [0], backend=self.backend)
1388
2019
 
1389
- @tf.function
1390
- def eval_comp_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
2020
+ @tf_function
2021
+ def eval_comp_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1391
2022
 
1392
- res=self.eval(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1393
- return res.P00,res.S0,res.S1,res.S2,res.S2L,res.j1,res.j2
2023
+ res = self.eval(image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off)
2024
+ return res.P00, res.S0, res.S1, res.S2, res.S2L, res.j1, res.j2
1394
2025
 
1395
- def eval_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1396
- p0,s0,s1,s2,s2l,j1,j2=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1397
- return scat(p0,s0,s1,s2,s2l,j1,j2,backend=self.backend)
1398
-
1399
-
1400
-
2026
+ def eval_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
2027
+ p0, s0, s1, s2, s2l, j1, j2 = self.eval_comp_fast(
2028
+ image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off
2029
+ )
2030
+ return scat(p0, s0, s1, s2, s2l, j1, j2, backend=self.backend)