foscat 3.1.6__py3-none-any.whl → 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat1D.py CHANGED
@@ -1,453 +1,608 @@
1
- import foscat.FoCUS as FOC
2
- import numpy as np
3
1
  import pickle
4
- import foscat.backend as bk
5
2
  import sys
6
3
 
4
+ import numpy as np
5
+
6
+ import foscat.backend as bk
7
+ import foscat.FoCUS as FOC
8
+
7
9
  # Vérifier si TensorFlow est importé et défini
8
- tf_defined = 'tensorflow' in sys.modules
10
+ tf_defined = "tensorflow" in sys.modules
9
11
 
10
12
  if tf_defined:
11
13
  import tensorflow as tf
12
- tf_function = tf.function # Facultatif : si vous voulez utiliser TensorFlow dans ce script
14
+
15
+ tf_function = (
16
+ tf.function
17
+ ) # Facultatif : si vous voulez utiliser TensorFlow dans ce script
13
18
  else:
19
+
14
20
  def tf_function(func):
15
21
  return func
16
-
22
+
23
+
17
24
  def read(filename):
18
- thescat=scat1D(1,1,1,1,1,[0],[0])
25
+ thescat = scat1D(1, 1, 1, 1, 1, [0], [0])
19
26
  return thescat.read(filename)
20
-
27
+
28
+
21
29
  class scat1D:
22
- def __init__(self,p00,s0,s1,s2,s2l,j1,j2,cross=False,backend=None):
23
- self.P00=p00
24
- self.S0=s0
25
- self.S1=s1
26
- self.S2=s2
27
- self.S2L=s2l
28
- self.j1=j1
29
- self.j2=j2
30
- self.cross=cross
31
- self.backend=backend
30
+ def __init__(self, p00, s0, s1, s2, s2l, j1, j2, cross=False, backend=None):
31
+ self.P00 = p00
32
+ self.S0 = s0
33
+ self.S1 = s1
34
+ self.S2 = s2
35
+ self.S2L = s2l
36
+ self.j1 = j1
37
+ self.j2 = j2
38
+ self.cross = cross
39
+ self.backend = backend
32
40
 
33
41
  # ---------------------------------------------−---------
34
- def build_flat(self,table):
35
- shape=table.shape
36
- ndata=1
37
- for k in range(1,len(table.shape)):
38
- ndata=ndata*table.shape[k]
39
- return self.backend.bk_reshape(table,[table.shape[0],ndata])
40
-
42
+ def build_flat(self, table):
43
+ shape = table.shape
44
+ ndata = 1
45
+ for k in range(1, len(shape)):
46
+ ndata = ndata * shape[k]
47
+ return self.backend.bk_reshape(table, [shape[0], ndata])
48
+
41
49
  # ---------------------------------------------−---------
42
- def flatten(self,S2L=True,P00=True):
43
- tmp=[self.build_flat(self.S0),self.build_flat(self.S1)]
44
-
50
+ def flatten(self, S2L=True, P00=True):
51
+ tmp = [self.build_flat(self.S0), self.build_flat(self.S1)]
52
+
45
53
  if P00:
46
- tmp=tmp+[self.build_flat(self.P00)]
54
+ tmp = tmp + [self.build_flat(self.P00)]
55
+
56
+ tmp = tmp + [self.build_flat(self.S2)]
47
57
 
48
- tmp=tmp+[self.build_flat(self.S2)]
49
-
50
58
  if S2L:
51
- tmp=tmp+[self.build_flat(self.S2L)]
52
-
53
- if isinstance(self.P00,np.ndarray):
54
- return np.concatenate(tmp,1)
59
+ tmp = tmp + [self.build_flat(self.S2L)]
60
+
61
+ if isinstance(self.P00, np.ndarray):
62
+ return np.concatenate(tmp, 1)
55
63
  else:
56
- return self.backend.bk_concat(tmp,1)
64
+ return self.backend.bk_concat(tmp, 1)
65
+
57
66
  # ---------------------------------------------−---------
58
- def flatten_name(self,S2L=True,P00=True):
59
-
60
- tmp=['S0']
67
+ def flatten_name(self, S2L=True, P00=True):
68
+
69
+ tmp = ["S0"]
70
+
71
+ tmp = tmp + ["S1_%d" % (k) for k in range(self.S1.shape[-1])]
61
72
 
62
-
63
- tmp=tmp+['S1_%d'%(k) for k in range(self.S1.shape[-1])]
64
-
65
73
  if P00:
66
- tmp=tmp+['P00_%d'%(k) for k in range(self.P00.shape[-1])]
74
+ tmp = tmp + ["P00_%d" % (k) for k in range(self.P00.shape[-1])]
75
+
76
+ j1, j2 = self.get_j_idx()
77
+
78
+ tmp = tmp + ["S2_%d-%d" % (j1[k], j2[k]) for k in range(self.S2.shape[-1])]
67
79
 
68
- j1,j2=self.get_j_idx()
69
-
70
- tmp=tmp+['S2_%d-%d'%(j1[k],j2[k]) for k in range(self.S2.shape[-1])]
71
-
72
80
  if S2L:
73
- tmp=tmp+['S2L_%d-%d'%(j1[k],j2[k]) for k in range(self.S2.shape[-1])]
74
-
81
+ tmp = tmp + ["S2L_%d-%d" % (j1[k], j2[k]) for k in range(self.S2.shape[-1])]
82
+
75
83
  return tmp
76
-
84
+
77
85
  def get_j_idx(self):
78
- return self.j1,self.j2
79
-
86
+ return self.j1, self.j2
87
+
80
88
  def get_S0(self):
81
- return(self.S0)
89
+ return self.S0
82
90
 
83
91
  def get_S1(self):
84
- return(self.S1)
85
-
92
+ return self.S1
93
+
86
94
  def get_S2(self):
87
- return(self.S2)
95
+ return self.S2
88
96
 
89
97
  def get_S2L(self):
90
- return(self.S2L)
98
+ return self.S2L
91
99
 
92
100
  def get_P00(self):
93
- return(self.P00)
101
+ return self.P00
94
102
 
95
103
  def reset_P00(self):
96
- self.P00=0*self.P00
104
+ self.P00 = 0 * self.P00
97
105
 
98
106
  def constant(self):
99
- return scat1D(self.backend.constant(self.P00 ), \
100
- self.backend.constant(self.S0 ), \
101
- self.backend.constant(self.S1 ), \
102
- self.backend.constant(self.S2 ), \
103
- self.backend.constant(self.S2L ), \
104
- self.j1 , \
105
- self.j2 ,backend=self.backend)
106
-
107
- def domult(self,x,y):
108
- if x.dtype==y.dtype:
109
- return x*y
110
-
107
+ return scat1D(
108
+ self.backend.constant(self.P00),
109
+ self.backend.constant(self.S0),
110
+ self.backend.constant(self.S1),
111
+ self.backend.constant(self.S2),
112
+ self.backend.constant(self.S2L),
113
+ self.j1,
114
+ self.j2,
115
+ backend=self.backend,
116
+ )
117
+
118
+ def domult(self, x, y):
119
+ if x.dtype == y.dtype:
120
+ return x * y
121
+
111
122
  if self.backend.bk_is_complex(x):
112
-
113
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
123
+
124
+ return self.backend.bk_complex(
125
+ self.backend.bk_real(x) * y, self.backend.bk_imag(x) * y
126
+ )
114
127
  else:
115
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
116
-
117
- def dodiv(self,x,y):
118
- if x.dtype==y.dtype:
119
- return x/y
128
+ return self.backend.bk_complex(
129
+ self.backend.bk_real(y) * x, self.backend.bk_imag(y) * x
130
+ )
131
+
132
+ def dodiv(self, x, y):
133
+ if x.dtype == y.dtype:
134
+ return x / y
120
135
  if self.backend.bk_is_complex(x):
121
-
122
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
136
+
137
+ return self.backend.bk_complex(
138
+ self.backend.bk_real(x) / y, self.backend.bk_imag(x) / y
139
+ )
123
140
  else:
124
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
125
-
126
- def domin(self,x,y):
127
- if x.dtype==y.dtype:
128
- return x-y
129
-
141
+ return self.backend.bk_complex(
142
+ x / self.backend.bk_real(y), x / self.backend.bk_imag(y)
143
+ )
144
+
145
+ def domin(self, x, y):
146
+ if x.dtype == y.dtype:
147
+ return x - y
148
+
130
149
  if self.backend.bk_is_complex(x):
131
-
132
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
150
+
151
+ return self.backend.bk_complex(
152
+ self.backend.bk_real(x) - y, self.backend.bk_imag(x) - y
153
+ )
133
154
  else:
134
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
135
-
136
- def doadd(self,x,y):
137
- if x.dtype==y.dtype:
138
- return x+y
139
-
155
+ return self.backend.bk_complex(
156
+ x - self.backend.bk_real(y), x - self.backend.bk_imag(y)
157
+ )
158
+
159
+ def doadd(self, x, y):
160
+ if x.dtype == y.dtype:
161
+ return x + y
162
+
140
163
  if self.backend.bk_is_complex(x):
141
-
142
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
164
+
165
+ return self.backend.bk_complex(
166
+ self.backend.bk_real(x) + y, self.backend.bk_imag(x) + y
167
+ )
143
168
  else:
144
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
145
-
146
- def relu(self):
147
-
148
- return scat1D(self.backend.bk_relu(self.P00), \
149
- self.backend.bk_relu(self.S0), \
150
- self.backend.bk_relu(self.S1), \
151
- self.backend.bk_relu(self.S2), \
152
- self.backend.bk_relu(self.S2L), \
153
- self.j1,self.j2,backend=self.backend)
154
-
155
- def __add__(self,other):
156
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
157
- isinstance(other, bool) or isinstance(other, scat1D)
158
-
169
+ return self.backend.bk_complex(
170
+ x + self.backend.bk_real(y), x + self.backend.bk_imag(y)
171
+ )
172
+
173
+ def __add__(self, other):
174
+ assert (
175
+ isinstance(other, float)
176
+ or isinstance(other, np.float32)
177
+ or isinstance(other, int)
178
+ or isinstance(other, bool)
179
+ or isinstance(other, scat1D)
180
+ )
181
+
159
182
  if isinstance(other, scat1D):
160
- return scat1D(self.doadd(self.P00,other.P00), \
161
- self.doadd(self.S0, other.S0), \
162
- self.doadd(self.S1, other.S1), \
163
- self.doadd(self.S2, other.S2), \
164
- self.doadd(self.S2L, other.S2L), \
165
- self.j1,self.j2,backend=self.backend)
183
+ return scat1D(
184
+ self.doadd(self.P00, other.P00),
185
+ self.doadd(self.S0, other.S0),
186
+ self.doadd(self.S1, other.S1),
187
+ self.doadd(self.S2, other.S2),
188
+ self.doadd(self.S2L, other.S2L),
189
+ self.j1,
190
+ self.j2,
191
+ backend=self.backend,
192
+ )
166
193
  else:
167
- return scat1D((self.P00+ other), \
168
- (self.S0+ other), \
169
- (self.S1+ other), \
170
- (self.S2+ other), \
171
- (self.S2L+ other), \
172
- self.j1,self.j2,backend=self.backend)
173
-
174
- def toreal(self,value):
194
+ return scat1D(
195
+ (self.P00 + other),
196
+ (self.S0 + other),
197
+ (self.S1 + other),
198
+ (self.S2 + other),
199
+ (self.S2L + other),
200
+ self.j1,
201
+ self.j2,
202
+ backend=self.backend,
203
+ )
204
+
205
+ def toreal(self, value):
175
206
  if value is None:
176
207
  return None
177
-
208
+
178
209
  return self.backend.bk_real(value)
179
210
 
180
- def addcomplex(self,value,amp):
211
+ def addcomplex(self, value, amp):
181
212
  if value is None:
182
213
  return None
183
-
184
- return self.backend.bk_complex(value,amp*value)
185
-
186
- def add_complex(self,amp):
187
- return scat1D(self.addcomplex(self.P00,amp), \
188
- self.addcomplex(self.S0,amp), \
189
- self.addcomplex(self.S1,amp), \
190
- self.addcomplex(self.S2,amp), \
191
- self.addcomplex(self.S2L,amp), \
192
- self.j1,self.j2,backend=self.backend)
193
-
214
+
215
+ return self.backend.bk_complex(value, amp * value)
216
+
217
+ def add_complex(self, amp):
218
+ return scat1D(
219
+ self.addcomplex(self.P00, amp),
220
+ self.addcomplex(self.S0, amp),
221
+ self.addcomplex(self.S1, amp),
222
+ self.addcomplex(self.S2, amp),
223
+ self.addcomplex(self.S2L, amp),
224
+ self.j1,
225
+ self.j2,
226
+ backend=self.backend,
227
+ )
228
+
194
229
  def real(self):
195
- return scat1D(self.toreal(self.P00), \
196
- self.toreal(self.S0), \
197
- self.toreal(self.S1), \
198
- self.toreal(self.S2), \
199
- self.toreal(self.S2L), \
200
- self.j1,self.j2,backend=self.backend)
201
-
202
- def __radd__(self,other):
230
+ return scat1D(
231
+ self.toreal(self.P00),
232
+ self.toreal(self.S0),
233
+ self.toreal(self.S1),
234
+ self.toreal(self.S2),
235
+ self.toreal(self.S2L),
236
+ self.j1,
237
+ self.j2,
238
+ backend=self.backend,
239
+ )
240
+
241
+ def __radd__(self, other):
203
242
  return self.__add__(other)
204
243
 
205
- def __truediv__(self,other):
206
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
207
- isinstance(other, bool) or isinstance(other, scat1D)
208
-
244
+ def __truediv__(self, other):
245
+ assert (
246
+ isinstance(other, float)
247
+ or isinstance(other, np.float32)
248
+ or isinstance(other, int)
249
+ or isinstance(other, bool)
250
+ or isinstance(other, scat1D)
251
+ )
252
+
209
253
  if isinstance(other, scat1D):
210
- return scat1D(self.dodiv(self.P00, other.P00), \
211
- self.dodiv(self.S0, other.S0), \
212
- self.dodiv(self.S1, other.S1), \
213
- self.dodiv(self.S2, other.S2), \
214
- self.dodiv(self.S2L, other.S2L), \
215
- self.j1,self.j2,backend=self.backend)
254
+ return scat1D(
255
+ self.dodiv(self.P00, other.P00),
256
+ self.dodiv(self.S0, other.S0),
257
+ self.dodiv(self.S1, other.S1),
258
+ self.dodiv(self.S2, other.S2),
259
+ self.dodiv(self.S2L, other.S2L),
260
+ self.j1,
261
+ self.j2,
262
+ backend=self.backend,
263
+ )
216
264
  else:
217
- return scat1D((self.P00/ other), \
218
- (self.S0/ other), \
219
- (self.S1/ other), \
220
- (self.S2/ other), \
221
- (self.S2L/ other), \
222
- self.j1,self.j2,backend=self.backend)
223
-
224
-
225
- def __rtruediv__(self,other):
226
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
227
- isinstance(other, bool) or isinstance(other, scat1D)
228
-
265
+ return scat1D(
266
+ (self.P00 / other),
267
+ (self.S0 / other),
268
+ (self.S1 / other),
269
+ (self.S2 / other),
270
+ (self.S2L / other),
271
+ self.j1,
272
+ self.j2,
273
+ backend=self.backend,
274
+ )
275
+
276
+ def __rtruediv__(self, other):
277
+ assert (
278
+ isinstance(other, float)
279
+ or isinstance(other, np.float32)
280
+ or isinstance(other, int)
281
+ or isinstance(other, bool)
282
+ or isinstance(other, scat1D)
283
+ )
284
+
229
285
  if isinstance(other, scat1D):
230
- return scat1D(self.dodiv(other.P00, self.P00), \
231
- self.dodiv(other.S0 , self.S0), \
232
- self.dodiv(other.S1 , self.S1), \
233
- self.dodiv(other.S2 , self.S2), \
234
- self.dodiv(other.S2L , self.S2L), \
235
- self.j1,self.j2,backend=self.backend)
286
+ return scat1D(
287
+ self.dodiv(other.P00, self.P00),
288
+ self.dodiv(other.S0, self.S0),
289
+ self.dodiv(other.S1, self.S1),
290
+ self.dodiv(other.S2, self.S2),
291
+ self.dodiv(other.S2L, self.S2L),
292
+ self.j1,
293
+ self.j2,
294
+ backend=self.backend,
295
+ )
236
296
  else:
237
- return scat1D((other/ self.P00), \
238
- (other / self.S0), \
239
- (other / self.S1), \
240
- (other / self.S2), \
241
- (other / self.S2L), \
242
- self.j1,self.j2,backend=self.backend)
243
-
244
- def __sub__(self,other):
245
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
246
- isinstance(other, bool) or isinstance(other, scat1D)
247
-
297
+ return scat1D(
298
+ (other / self.P00),
299
+ (other / self.S0),
300
+ (other / self.S1),
301
+ (other / self.S2),
302
+ (other / self.S2L),
303
+ self.j1,
304
+ self.j2,
305
+ backend=self.backend,
306
+ )
307
+
308
+ def __sub__(self, other):
309
+ assert (
310
+ isinstance(other, float)
311
+ or isinstance(other, np.float32)
312
+ or isinstance(other, int)
313
+ or isinstance(other, bool)
314
+ or isinstance(other, scat1D)
315
+ )
316
+
248
317
  if isinstance(other, scat1D):
249
- return scat1D(self.domin(self.P00, other.P00), \
250
- self.domin(self.S0, other.S0), \
251
- self.domin(self.S1, other.S1), \
252
- self.domin(self.S2, other.S2), \
253
- self.domin(self.S2L, other.S2L), \
254
- self.j1,self.j2,backend=self.backend)
318
+ return scat1D(
319
+ self.domin(self.P00, other.P00),
320
+ self.domin(self.S0, other.S0),
321
+ self.domin(self.S1, other.S1),
322
+ self.domin(self.S2, other.S2),
323
+ self.domin(self.S2L, other.S2L),
324
+ self.j1,
325
+ self.j2,
326
+ backend=self.backend,
327
+ )
255
328
  else:
256
- return scat1D((self.P00- other), \
257
- (self.S0- other), \
258
- (self.S1- other), \
259
- (self.S2- other), \
260
- (self.S2L- other), \
261
- self.j1,self.j2,backend=self.backend)
262
-
263
- def __rsub__(self,other):
264
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
265
- isinstance(other, bool) or isinstance(other, scat1D)
266
-
329
+ return scat1D(
330
+ (self.P00 - other),
331
+ (self.S0 - other),
332
+ (self.S1 - other),
333
+ (self.S2 - other),
334
+ (self.S2L - other),
335
+ self.j1,
336
+ self.j2,
337
+ backend=self.backend,
338
+ )
339
+
340
+ def __rsub__(self, other):
341
+ assert (
342
+ isinstance(other, float)
343
+ or isinstance(other, np.float32)
344
+ or isinstance(other, int)
345
+ or isinstance(other, bool)
346
+ or isinstance(other, scat1D)
347
+ )
348
+
267
349
  if isinstance(other, scat1D):
268
- return scat1D(self.domin(other.P00,self.P00), \
269
- self.domin(other.S0, self.S0), \
270
- self.domin(other.S1, self.S1), \
271
- self.domin(other.S2, self.S2), \
272
- self.domin(other.S2L, self.S2L), \
273
- self.j1,self.j2,backend=self.backend)
350
+ return scat1D(
351
+ self.domin(other.P00, self.P00),
352
+ self.domin(other.S0, self.S0),
353
+ self.domin(other.S1, self.S1),
354
+ self.domin(other.S2, self.S2),
355
+ self.domin(other.S2L, self.S2L),
356
+ self.j1,
357
+ self.j2,
358
+ backend=self.backend,
359
+ )
274
360
  else:
275
- return scat1D((other-self.P00), \
276
- (other-self.S0), \
277
- (other-self.S1), \
278
- (other-self.S2), \
279
- (other-self.S2L), \
280
- self.j1,self.j2,backend=self.backend)
281
-
282
- def __mul__(self,other):
283
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
284
- isinstance(other, bool) or isinstance(other, scat1D)
285
-
361
+ return scat1D(
362
+ (other - self.P00),
363
+ (other - self.S0),
364
+ (other - self.S1),
365
+ (other - self.S2),
366
+ (other - self.S2L),
367
+ self.j1,
368
+ self.j2,
369
+ backend=self.backend,
370
+ )
371
+
372
+ def __mul__(self, other):
373
+ assert (
374
+ isinstance(other, float)
375
+ or isinstance(other, np.float32)
376
+ or isinstance(other, int)
377
+ or isinstance(other, bool)
378
+ or isinstance(other, scat1D)
379
+ )
380
+
286
381
  if isinstance(other, scat1D):
287
- return scat1D(self.domult(self.P00, other.P00), \
288
- self.domult(self.S0, other.S0), \
289
- self.domult(self.S1, other.S1), \
290
- self.domult(self.S2, other.S2), \
291
- self.domult(self.S2L, other.S2L), \
292
- self.j1,self.j2,backend=self.backend)
382
+ return scat1D(
383
+ self.domult(self.P00, other.P00),
384
+ self.domult(self.S0, other.S0),
385
+ self.domult(self.S1, other.S1),
386
+ self.domult(self.S2, other.S2),
387
+ self.domult(self.S2L, other.S2L),
388
+ self.j1,
389
+ self.j2,
390
+ backend=self.backend,
391
+ )
293
392
  else:
294
- return scat1D((self.P00* other), \
295
- (self.S0* other), \
296
- (self.S1* other), \
297
- (self.S2* other), \
298
- (self.S2L* other), \
299
- self.j1,self.j2,backend=self.backend)
393
+ return scat1D(
394
+ (self.P00 * other),
395
+ (self.S0 * other),
396
+ (self.S1 * other),
397
+ (self.S2 * other),
398
+ (self.S2L * other),
399
+ self.j1,
400
+ self.j2,
401
+ backend=self.backend,
402
+ )
403
+
300
404
  def relu(self):
301
- return scat1D(self.backend.bk_relu(self.P00),
302
- self.backend.bk_relu(self.S0),
303
- self.backend.bk_relu(self.S1),
304
- self.backend.bk_relu(self.S2),
305
- self.backend.bk_relu(self.S2L), \
306
- self.j1,self.j2,backend=self.backend)
307
-
308
-
309
- def __rmul__(self,other):
310
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
311
- isinstance(other, bool) or isinstance(other, scat1D)
312
-
405
+ return scat1D(
406
+ self.backend.bk_relu(self.P00),
407
+ self.backend.bk_relu(self.S0),
408
+ self.backend.bk_relu(self.S1),
409
+ self.backend.bk_relu(self.S2),
410
+ self.backend.bk_relu(self.S2L),
411
+ self.j1,
412
+ self.j2,
413
+ backend=self.backend,
414
+ )
415
+
416
+ def __rmul__(self, other):
417
+ assert (
418
+ isinstance(other, float)
419
+ or isinstance(other, np.float32)
420
+ or isinstance(other, int)
421
+ or isinstance(other, bool)
422
+ or isinstance(other, scat1D)
423
+ )
424
+
313
425
  if isinstance(other, scat1D):
314
- return scat1D(self.domult(self.P00, other.P00), \
315
- self.domult(self.S0, other.S0), \
316
- self.domult(self.S1, other.S1), \
317
- self.domult(self.S2, other.S2), \
318
- self.domult(self.S2L, other.S2L), \
319
- self.j1,self.j2,backend=self.backend)
426
+ return scat1D(
427
+ self.domult(self.P00, other.P00),
428
+ self.domult(self.S0, other.S0),
429
+ self.domult(self.S1, other.S1),
430
+ self.domult(self.S2, other.S2),
431
+ self.domult(self.S2L, other.S2L),
432
+ self.j1,
433
+ self.j2,
434
+ backend=self.backend,
435
+ )
320
436
  else:
321
- return scat1D((self.P00* other), \
322
- (self.S0* other), \
323
- (self.S1* other), \
324
- (self.S2* other), \
325
- (self.S2L* other), \
326
- self.j1,self.j2,backend=self.backend)
327
-
328
- def l1_abs(self,x):
329
- y=self.get_np(x)
437
+ return scat1D(
438
+ (self.P00 * other),
439
+ (self.S0 * other),
440
+ (self.S1 * other),
441
+ (self.S2 * other),
442
+ (self.S2L * other),
443
+ self.j1,
444
+ self.j2,
445
+ backend=self.backend,
446
+ )
447
+
448
+ def l1_abs(self, x):
449
+ y = self.get_np(x)
330
450
  if self.backend.bk_is_complex(y):
331
- tmp=y.real*y.real+y.imag*y.imag
332
- tmp=np.sign(tmp)*np.sqrt(np.fabs(tmp))
333
- y=tmp
334
-
335
- return(y)
336
-
337
- def plot(self,name=None,hold=True,color='blue',lw=1,legend=True):
451
+ tmp = y.real * y.real + y.imag * y.imag
452
+ tmp = np.sign(tmp) * np.sqrt(np.fabs(tmp))
453
+ y = tmp
454
+
455
+ return y
456
+
457
+ def plot(self, name=None, hold=True, color="blue", lw=1, legend=True):
338
458
 
339
459
  import matplotlib.pyplot as plt
340
460
 
341
- j1,j2=self.get_j_idx()
342
-
461
+ j1, j2 = self.get_j_idx()
462
+
343
463
  if name is None:
344
- name=''
464
+ name = ""
345
465
 
346
466
  if hold:
347
- plt.figure(figsize=(16,8))
348
-
349
- test=None
467
+ plt.figure(figsize=(16, 8))
468
+
469
+ test = None
350
470
  plt.subplot(2, 2, 1)
351
- tmp=abs(self.get_np(self.S1))
352
- if len(tmp.shape)==4:
471
+ tmp = abs(self.get_np(self.S1))
472
+ if len(tmp.shape) == 4:
353
473
  for i1 in range(tmp.shape[0]):
354
474
  for i2 in range(tmp.shape[1]):
355
475
  if test is None:
356
- test=1
357
- plt.plot(tmp[i1,i2,:],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
476
+ test = 1
477
+ plt.plot(
478
+ tmp[i1, i2, :],
479
+ color=color,
480
+ label=r"%s $S_{1}$" % (name),
481
+ lw=lw,
482
+ )
358
483
  else:
359
- plt.plot(tmp[i1,i2,:],color=color, lw=lw)
484
+ plt.plot(tmp[i1, i2, :], color=color, lw=lw)
360
485
  else:
361
486
  for i1 in range(tmp.shape[0]):
362
487
  if test is None:
363
- test=1
364
- plt.plot(tmp[i1,:],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
488
+ test = 1
489
+ plt.plot(
490
+ tmp[i1, :], color=color, label=r"%s $S_{1}$" % (name), lw=lw
491
+ )
365
492
  else:
366
- plt.plot(tmp[i1,:],color=color, lw=lw)
367
- plt.yscale('log')
368
- plt.ylabel('S1')
369
- plt.xlabel(r'$j_{1}$')
493
+ plt.plot(tmp[i1, :], color=color, lw=lw)
494
+ plt.yscale("log")
495
+ plt.ylabel("S1")
496
+ plt.xlabel(r"$j_{1}$")
370
497
  plt.legend()
371
498
 
372
- test=None
499
+ test = None
373
500
  plt.subplot(2, 2, 2)
374
- tmp=abs(self.get_np(self.P00))
375
- if len(tmp.shape)==4:
501
+ tmp = abs(self.get_np(self.P00))
502
+ if len(tmp.shape) == 4:
376
503
  for i1 in range(tmp.shape[0]):
377
504
  for i2 in range(tmp.shape[0]):
378
505
  if test is None:
379
- test=1
380
- plt.plot(tmp[i1,i2,:],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
506
+ test = 1
507
+ plt.plot(
508
+ tmp[i1, i2, :],
509
+ color=color,
510
+ label=r"%s $P_{00}$" % (name),
511
+ lw=lw,
512
+ )
381
513
  else:
382
- plt.plot(tmp[i1,i2,:],color=color, lw=lw)
514
+ plt.plot(tmp[i1, i2, :], color=color, lw=lw)
383
515
  else:
384
516
  for i1 in range(tmp.shape[0]):
385
517
  if test is None:
386
- test=1
387
- plt.plot(tmp[i1,:],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
518
+ test = 1
519
+ plt.plot(
520
+ tmp[i1, :], color=color, label=r"%s $P_{00}$" % (name), lw=lw
521
+ )
388
522
  else:
389
- plt.plot(tmp[i1,:],color=color, lw=lw)
390
- plt.yscale('log')
391
- plt.ylabel('P00')
392
- plt.xlabel(r'$j_{1}$')
523
+ plt.plot(tmp[i1, :], color=color, lw=lw)
524
+ plt.yscale("log")
525
+ plt.ylabel("P00")
526
+ plt.xlabel(r"$j_{1}$")
393
527
  plt.legend()
394
-
395
- ax1=plt.subplot(2, 2, 3)
528
+
529
+ ax1 = plt.subplot(2, 2, 3)
396
530
  ax2 = ax1.twiny()
397
- n=0
398
- tmp=abs(self.get_np(self.S2))
399
- lname=r'%s $S_{2}$' % (name)
400
- ax1.set_ylabel(r'$S_{2}$ [L1 norm]')
401
- test=None
402
- tabx=[]
403
- tabnx=[]
404
- tab2x=[]
405
- tab2nx=[]
406
- if len(tmp.shape)==5:
531
+ n = 0
532
+ tmp = abs(self.get_np(self.S2))
533
+ lname = r"%s $S_{2}$" % (name)
534
+ ax1.set_ylabel(r"$S_{2}$ [L1 norm]")
535
+ test = None
536
+ tabx = []
537
+ tabnx = []
538
+ tab2x = []
539
+ tab2nx = []
540
+ if len(tmp.shape) == 5:
407
541
  for i0 in range(tmp.shape[0]):
408
542
  for i1 in range(tmp.shape[1]):
409
- for i2 in range(j1.max()+1):
410
- if j2[j1==i2].shape[0]==1:
411
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2],'.', \
412
- color=color, lw=lw)
543
+ for i2 in range(j1.max() + 1):
544
+ if j2[j1 == i2].shape[0] == 1:
545
+ ax1.plot(
546
+ j2[j1 == i2] + n,
547
+ tmp[i0, i1, j1 == i2],
548
+ ".",
549
+ color=color,
550
+ lw=lw,
551
+ )
413
552
  else:
414
553
  if legend and test is None:
415
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
416
- color=color, label=lname, lw=lw)
417
- test=1
418
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
419
- color=color, lw=lw)
420
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
421
- tabx=tabx+[k+n for k in j2[j1==i2]]
422
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
423
- tab2nx=tab2nx+['%d'%(i2)]
424
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
425
- n=n+j2[j1==i2].shape[0]-1
554
+ ax1.plot(
555
+ j2[j1 == i2] + n,
556
+ tmp[i0, i1, j1 == i2],
557
+ color=color,
558
+ label=lname,
559
+ lw=lw,
560
+ )
561
+ test = 1
562
+ ax1.plot(
563
+ j2[j1 == i2] + n,
564
+ tmp[i0, i1, j1 == i2],
565
+ color=color,
566
+ lw=lw,
567
+ )
568
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
569
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
570
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
571
+ tab2nx = tab2nx + ["%d" % (i2)]
572
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
573
+ n = n + j2[j1 == i2].shape[0] - 1
426
574
  else:
427
575
  for i0 in range(tmp.shape[0]):
428
- for i2 in range(j1.max()+1):
429
- if j2[j1==i2].shape[0]==1:
430
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2],'.', \
431
- color=color, lw=lw)
576
+ for i2 in range(j1.max() + 1):
577
+ if j2[j1 == i2].shape[0] == 1:
578
+ ax1.plot(
579
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], ".", color=color, lw=lw
580
+ )
432
581
  else:
433
582
  if legend and test is None:
434
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
435
- color=color, label=lname, lw=lw)
436
- test=1
437
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
438
- color=color, lw=lw)
439
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
440
- tabx=tabx+[k+n for k in j2[j1==i2]]
441
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
442
- tab2nx=tab2nx+['%d'%(i2)]
443
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
444
- n=n+j2[j1==i2].shape[0]-1
445
- plt.yscale('log')
446
- ax1.set_xlim(0,n+2)
583
+ ax1.plot(
584
+ j2[j1 == i2] + n,
585
+ tmp[i0, j1 == i2],
586
+ color=color,
587
+ label=lname,
588
+ lw=lw,
589
+ )
590
+ test = 1
591
+ ax1.plot(
592
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], color=color, lw=lw
593
+ )
594
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
595
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
596
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
597
+ tab2nx = tab2nx + ["%d" % (i2)]
598
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
599
+ n = n + j2[j1 == i2].shape[0] - 1
600
+ plt.yscale("log")
601
+ ax1.set_xlim(0, n + 2)
447
602
  ax1.set_xticks(tabx)
448
- ax1.set_xticklabels(tabnx,fontsize=6)
449
- ax1.set_xlabel(r"$j_{2}$ ",fontsize=6)
450
-
603
+ ax1.set_xticklabels(tabnx, fontsize=6)
604
+ ax1.set_xlabel(r"$j_{2}$ ", fontsize=6)
605
+
451
606
  # Move twinned axis ticks and label from top to bottom
452
607
  ax2.xaxis.set_ticks_position("bottom")
453
608
  ax2.xaxis.set_label_position("bottom")
@@ -455,7 +610,7 @@ class scat1D:
455
610
  # Offset the twin axis below the host
456
611
  ax2.spines["bottom"].set_position(("axes", -0.15))
457
612
 
458
- # Turn on the frame for the twin axis, but then hide all
613
+ # Turn on the frame for the twin axis, but then hide all
459
614
  # but the bottom spine
460
615
  ax2.set_frame_on(True)
461
616
  ax2.patch.set_visible(False)
@@ -463,68 +618,91 @@ class scat1D:
463
618
  for sp in ax2.spines.values():
464
619
  sp.set_visible(False)
465
620
  ax2.spines["bottom"].set_visible(True)
466
- ax2.set_xlim(0,n+2)
621
+ ax2.set_xlim(0, n + 2)
467
622
  ax2.set_xticks(tab2x)
468
- ax2.set_xticklabels(tab2nx,fontsize=6)
469
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
623
+ ax2.set_xticklabels(tab2nx, fontsize=6)
624
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
470
625
  ax1.legend(frameon=0)
471
-
472
- ax1=plt.subplot(2, 2, 4)
626
+
627
+ ax1 = plt.subplot(2, 2, 4)
473
628
  ax2 = ax1.twiny()
474
- n=0
475
- tmp=abs(self.get_np(self.S2L))
476
- lname=r'%s $S2_{2}$' % (name)
477
- ax1.set_ylabel(r'$S_{2}$ [L2 norm]')
478
- test=None
479
- tabx=[]
480
- tabnx=[]
481
- tab2x=[]
482
- tab2nx=[]
483
- if len(tmp.shape)==5:
629
+ n = 0
630
+ tmp = abs(self.get_np(self.S2L))
631
+ lname = r"%s $S2_{2}$" % (name)
632
+ ax1.set_ylabel(r"$S_{2}$ [L2 norm]")
633
+ test = None
634
+ tabx = []
635
+ tabnx = []
636
+ tab2x = []
637
+ tab2nx = []
638
+ if len(tmp.shape) == 5:
484
639
  for i0 in range(tmp.shape[0]):
485
640
  for i1 in range(tmp.shape[1]):
486
- for i2 in range(j1.max()+1):
487
- if j2[j1==i2].shape[0]==1:
488
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2],'.', \
489
- color=color, lw=lw)
641
+ for i2 in range(j1.max() + 1):
642
+ if j2[j1 == i2].shape[0] == 1:
643
+ ax1.plot(
644
+ j2[j1 == i2] + n,
645
+ tmp[i0, i1, j1 == i2],
646
+ ".",
647
+ color=color,
648
+ lw=lw,
649
+ )
490
650
  else:
491
651
  if legend and test is None:
492
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
493
- color=color, label=lname, lw=lw)
494
- test=1
495
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
496
- color=color, lw=lw)
497
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
498
- tabx=tabx+[k+n for k in j2[j1==i2]]
499
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
500
- tab2nx=tab2nx+['%d'%(i2)]
501
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
502
- n=n+j2[j1==i2].shape[0]-1
652
+ ax1.plot(
653
+ j2[j1 == i2] + n,
654
+ tmp[i0, i1, j1 == i2],
655
+ color=color,
656
+ label=lname,
657
+ lw=lw,
658
+ )
659
+ test = 1
660
+ ax1.plot(
661
+ j2[j1 == i2] + n,
662
+ tmp[i0, i1, j1 == i2],
663
+ color=color,
664
+ lw=lw,
665
+ )
666
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
667
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
668
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
669
+ tab2nx = tab2nx + ["%d" % (i2)]
670
+ ax1.axvline(
671
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
672
+ )
673
+ n = n + j2[j1 == i2].shape[0] - 1
503
674
  else:
504
675
  for i0 in range(tmp.shape[0]):
505
- for i2 in range(j1.max()+1):
506
- if j2[j1==i2].shape[0]==1:
507
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2],'.', \
508
- color=color, lw=lw)
676
+ for i2 in range(j1.max() + 1):
677
+ if j2[j1 == i2].shape[0] == 1:
678
+ ax1.plot(
679
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], ".", color=color, lw=lw
680
+ )
509
681
  else:
510
682
  if legend and test is None:
511
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
512
- color=color, label=lname, lw=lw)
513
- test=1
514
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
515
- color=color, lw=lw)
516
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
517
- tabx=tabx+[k+n for k in j2[j1==i2]]
518
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
519
- tab2nx=tab2nx+['%d'%(i2)]
520
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
521
- n=n+j2[j1==i2].shape[0]-1
522
- plt.yscale('log')
523
- ax1.set_xlim(-1,n+3)
683
+ ax1.plot(
684
+ j2[j1 == i2] + n,
685
+ tmp[i0, j1 == i2],
686
+ color=color,
687
+ label=lname,
688
+ lw=lw,
689
+ )
690
+ test = 1
691
+ ax1.plot(
692
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], color=color, lw=lw
693
+ )
694
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
695
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
696
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
697
+ tab2nx = tab2nx + ["%d" % (i2)]
698
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
699
+ n = n + j2[j1 == i2].shape[0] - 1
700
+ plt.yscale("log")
701
+ ax1.set_xlim(-1, n + 3)
524
702
  ax1.set_xticks(tabx)
525
- ax1.set_xticklabels(tabnx,fontsize=6)
526
- ax1.set_xlabel(r"$j_{2}$",fontsize=6)
527
-
703
+ ax1.set_xticklabels(tabnx, fontsize=6)
704
+ ax1.set_xlabel(r"$j_{2}$", fontsize=6)
705
+
528
706
  # Move twinned axis ticks and label from top to bottom
529
707
  ax2.xaxis.set_ticks_position("bottom")
530
708
  ax2.xaxis.set_label_position("bottom")
@@ -532,7 +710,7 @@ class scat1D:
532
710
  # Offset the twin axis below the host
533
711
  ax2.spines["bottom"].set_position(("axes", -0.15))
534
712
 
535
- # Turn on the frame for the twin axis, but then hide all
713
+ # Turn on the frame for the twin axis, but then hide all
536
714
  # but the bottom spine
537
715
  ax2.set_frame_on(True)
538
716
  ax2.patch.set_visible(False)
@@ -540,125 +718,182 @@ class scat1D:
540
718
  for sp in ax2.spines.values():
541
719
  sp.set_visible(False)
542
720
  ax2.spines["bottom"].set_visible(True)
543
- ax2.set_xlim(0,n+3)
721
+ ax2.set_xlim(0, n + 3)
544
722
  ax2.set_xticks(tab2x)
545
- ax2.set_xticklabels(tab2nx,fontsize=6)
546
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
723
+ ax2.set_xticklabels(tab2nx, fontsize=6)
724
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
547
725
  ax1.legend(frameon=0)
548
-
549
- def save(self,filename):
550
- outlist=[self.get_S0().numpy(), \
551
- self.get_S1().numpy(), \
552
- self.get_S2().numpy(), \
553
- self.get_S2L().numpy(), \
554
- self.get_P00().numpy(), \
555
- self.j1, \
556
- self.j2]
557
-
558
- myout=open("%s.pkl"%(filename),"wb")
559
- pickle.dump(outlist,myout)
726
+
727
+ def save(self, filename):
728
+ outlist = [
729
+ self.get_S0().numpy(),
730
+ self.get_S1().numpy(),
731
+ self.get_S2().numpy(),
732
+ self.get_S2L().numpy(),
733
+ self.get_P00().numpy(),
734
+ self.j1,
735
+ self.j2,
736
+ ]
737
+
738
+ myout = open("%s.pkl" % (filename), "wb")
739
+ pickle.dump(outlist, myout)
560
740
  myout.close()
561
741
 
562
-
563
- def read(self,filename):
564
-
565
- outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
742
+ def read(self, filename):
743
+
744
+ outlist = pickle.load(open("%s.pkl" % (filename), "rb"))
745
+
746
+ return scat1D(
747
+ outlist[4],
748
+ outlist[0],
749
+ outlist[1],
750
+ outlist[2],
751
+ outlist[3],
752
+ outlist[5],
753
+ outlist[6],
754
+ backend=bk.foscat_backend("numpy"),
755
+ )
566
756
 
567
- return scat1D(outlist[4],outlist[0],outlist[1],outlist[2],outlist[3],outlist[5],outlist[6],backend=bk.foscat_backend('numpy'))
568
-
569
- def get_np(self,x):
757
+ def get_np(self, x):
570
758
  if isinstance(x, np.ndarray):
571
759
  return x
572
760
  else:
573
761
  return x.numpy()
574
762
 
575
763
  def std(self):
576
- return np.sqrt(((abs(self.get_np(self.S0)).std())**2+ \
577
- (abs(self.get_np(self.S1)).std())**2+ \
578
- (abs(self.get_np(self.S2)).std())**2+ \
579
- (abs(self.get_np(self.S2L)).std())**2+ \
580
- (abs(self.get_np(self.P00)).std())**2)/4)
764
+ return np.sqrt(
765
+ (
766
+ (abs(self.get_np(self.S0)).std()) ** 2
767
+ + (abs(self.get_np(self.S1)).std()) ** 2
768
+ + (abs(self.get_np(self.S2)).std()) ** 2
769
+ + (abs(self.get_np(self.S2L)).std()) ** 2
770
+ + (abs(self.get_np(self.P00)).std()) ** 2
771
+ )
772
+ / 4
773
+ )
581
774
 
582
775
  def mean(self):
583
- return abs(self.get_np(self.S0).mean()+ \
584
- self.get_np(self.S1).mean()+ \
585
- self.get_np(self.S2).mean()+ \
586
- self.get_np(self.S2L).mean()+ \
587
- self.get_np(self.P00).mean())/3
588
-
589
-
776
+ return (
777
+ abs(
778
+ self.get_np(self.S0).mean()
779
+ + self.get_np(self.S1).mean()
780
+ + self.get_np(self.S2).mean()
781
+ + self.get_np(self.S2L).mean()
782
+ + self.get_np(self.P00).mean()
783
+ )
784
+ / 3
785
+ )
590
786
 
591
787
  # ---------------------------------------------−---------
592
- def cleanval(self,x):
593
- x=x.numpy()
594
- x[np.isfinite(x)==False]=np.median(x[np.isfinite(x)])
788
+ def cleanval(self, x):
789
+ x = x.numpy()
790
+ x[~np.isfinite(x)] = np.median(x[np.isfinite(x)])
595
791
  return x
596
792
 
597
793
  def filter_inf(self):
598
- S1 = self.cleanval(self.S1)
599
- S0 = self.cleanval(self.S0)
794
+ S1 = self.cleanval(self.S1)
795
+ S0 = self.cleanval(self.S0)
600
796
  P00 = self.cleanval(self.P00)
601
- S2 = self.cleanval(self.S2)
797
+ S2 = self.cleanval(self.S2)
602
798
  S2L = self.cleanval(self.S2L)
603
799
 
604
- return scat1D(P00,S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
800
+ return scat1D(P00, S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
605
801
 
606
802
  # ---------------------------------------------−---------
607
- def interp(self,nscale,extend=False,constant=False,threshold=1E30):
608
-
609
- if nscale+2>self.S1.shape[1]:
610
- print('Can not *interp* %d with a statistic described over %d'%(nscale,self.S1.shape[1]))
611
- return scat1D(self.P00,self.S0,self.S1,self.S2,self.S2L,self.j1,self.j2,backend=self.backend)
612
- if isinstance(self.S1,np.ndarray):
613
- s1=self.S1
614
- p0=self.P00
615
- s2=self.S2
616
- s2l=self.S2L
803
+ def interp(self, nscale, extend=False, constant=False, threshold=1e30):
804
+
805
+ if nscale + 2 > self.S1.shape[1]:
806
+ print(
807
+ "Can not *interp* %d with a statistic described over %d"
808
+ % (nscale, self.S1.shape[1])
809
+ )
810
+ return scat1D(
811
+ self.P00,
812
+ self.S0,
813
+ self.S1,
814
+ self.S2,
815
+ self.S2L,
816
+ self.j1,
817
+ self.j2,
818
+ backend=self.backend,
819
+ )
820
+ if isinstance(self.S1, np.ndarray):
821
+ s1 = self.S1
822
+ p0 = self.P00
823
+ s2 = self.S2
824
+ s2l = self.S2L
617
825
  else:
618
- s1=self.S1.numpy()
619
- p0=self.P00.numpy()
620
- s2=self.S2.numpy()
621
- s2l=self.S2L.numpy()
622
-
623
- if isinstance(threshold,scat1D):
624
- if isinstance(threshold.S1,np.ndarray):
625
- s1th=threshold.S1
626
- p0th=threshold.P00
627
- s2th=threshold.S2
628
- s2lth=threshold.S2L
826
+ s1 = self.S1.numpy()
827
+ p0 = self.P00.numpy()
828
+ s2 = self.S2.numpy()
829
+ s2l = self.S2L.numpy()
830
+
831
+ if isinstance(threshold, scat1D):
832
+ if isinstance(threshold.S1, np.ndarray):
833
+ s1th = threshold.S1
834
+ p0th = threshold.P00
835
+ s2th = threshold.S2
836
+ s2lth = threshold.S2L
629
837
  else:
630
- s1th=threshold.S1.numpy()
631
- p0th=threshold.P00.numpy()
632
- s2th=threshold.S2.numpy()
633
- s2lth=threshold.S2L.numpy()
838
+ s1th = threshold.S1.numpy()
839
+ p0th = threshold.P00.numpy()
840
+ s2th = threshold.S2.numpy()
841
+ s2lth = threshold.S2L.numpy()
634
842
  else:
635
- s1th=threshold+0*s1
636
- p0th=threshold+0*p0
637
- s2th=threshold+0*s2
638
- s2lth=threshold+0*s2l
843
+ s1th = threshold + 0 * s1
844
+ p0th = threshold + 0 * p0
845
+ s2th = threshold + 0 * s2
846
+ s2lth = threshold + 0 * s2l
639
847
 
640
848
  for k in range(nscale):
641
849
  if constant:
642
- s1[:,nscale-1-k,:]=s1[:,nscale-k,:]
643
- p0[:,nscale-1-k,:]=p0[:,nscale-k,:]
850
+ s1[:, nscale - 1 - k, :] = s1[:, nscale - k, :]
851
+ p0[:, nscale - 1 - k, :] = p0[:, nscale - k, :]
644
852
  else:
645
- idx=np.where((s1[:,nscale+1-k,:]>0)*(s1[:,nscale+2-k,:]>0)*(s1[:,nscale-k,:]<s1th[:,nscale-k,:]))
646
- if len(idx[0])>0:
647
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(s1[idx[0],nscale+1-k,idx[1]])-2*np.log(s1[idx[0],nscale+2-k,idx[1]]))
648
- idx=np.where((s1[:,nscale-k,:]>0)*(s1[:,nscale+1-k,:]>0)*(s1[:,nscale-1-k,:]<s1th[:,nscale-1-k,:]))
649
- if len(idx[0])>0:
650
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(s1[idx[0],nscale-k,idx[1]])-np.log(s1[idx[0],nscale+1-k,idx[1]]))
651
-
652
- idx=np.where((p0[:,nscale+1-k,:]>0)*(p0[:,nscale+2-k,:]>0)*(p0[:,nscale-k,:]<p0th[:,nscale-k,:]))
653
- if len(idx[0])>0:
654
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(p0[idx[0],nscale+1-k,idx[1]])-2*np.log(p0[idx[0],nscale+2-k,idx[1]]))
655
-
656
- idx=np.where((p0[:,nscale-k,:]>0)*(p0[:,nscale+1-k,:]>0)*(p0[:,nscale-1-k,:]<p0th[:,nscale-1-k,:]))
657
- if len(idx[0])>0:
658
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(p0[idx[0],nscale-k,idx[1]])-np.log(p0[idx[0],nscale+1-k,idx[1]]))
659
-
660
-
661
- j1,j2=self.get_j_idx()
853
+ idx = np.where(
854
+ (s1[:, nscale + 1 - k, :] > 0)
855
+ * (s1[:, nscale + 2 - k, :] > 0)
856
+ * (s1[:, nscale - k, :] < s1th[:, nscale - k, :])
857
+ )
858
+ if len(idx[0]) > 0:
859
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
860
+ 3 * np.log(s1[idx[0], nscale + 1 - k, idx[1]])
861
+ - 2 * np.log(s1[idx[0], nscale + 2 - k, idx[1]])
862
+ )
863
+ idx = np.where(
864
+ (s1[:, nscale - k, :] > 0)
865
+ * (s1[:, nscale + 1 - k, :] > 0)
866
+ * (s1[:, nscale - 1 - k, :] < s1th[:, nscale - 1 - k, :])
867
+ )
868
+ if len(idx[0]) > 0:
869
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
870
+ 2 * np.log(s1[idx[0], nscale - k, idx[1]])
871
+ - np.log(s1[idx[0], nscale + 1 - k, idx[1]])
872
+ )
873
+
874
+ idx = np.where(
875
+ (p0[:, nscale + 1 - k, :] > 0)
876
+ * (p0[:, nscale + 2 - k, :] > 0)
877
+ * (p0[:, nscale - k, :] < p0th[:, nscale - k, :])
878
+ )
879
+ if len(idx[0]) > 0:
880
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
881
+ 3 * np.log(p0[idx[0], nscale + 1 - k, idx[1]])
882
+ - 2 * np.log(p0[idx[0], nscale + 2 - k, idx[1]])
883
+ )
884
+
885
+ idx = np.where(
886
+ (p0[:, nscale - k, :] > 0)
887
+ * (p0[:, nscale + 1 - k, :] > 0)
888
+ * (p0[:, nscale - 1 - k, :] < p0th[:, nscale - 1 - k, :])
889
+ )
890
+ if len(idx[0]) > 0:
891
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
892
+ 2 * np.log(p0[idx[0], nscale - k, idx[1]])
893
+ - np.log(p0[idx[0], nscale + 1 - k, idx[1]])
894
+ )
895
+
896
+ j1, j2 = self.get_j_idx()
662
897
 
663
898
  for k in range(nscale):
664
899
 
@@ -671,460 +906,641 @@ class scat1D:
671
906
  s2l[:,i0]=np.exp(2*np.log(s2l[:,i1])-np.log(s2l[:,i2]))
672
907
  """
673
908
 
674
- for l in range(nscale-k):
675
- i0=np.where((j1==nscale-1-k-l)*(j2==nscale-1-k))[0]
676
- i1=np.where((j1==nscale-1-k-l)*(j2==nscale -k))[0]
677
- i2=np.where((j1==nscale-1-k-l)*(j2==nscale+1-k))[0]
678
- i3=np.where((j1==nscale-1-k-l)*(j2==nscale+2-k))[0]
679
-
909
+ for l_orient in range(nscale - k):
910
+ i0 = np.where(
911
+ (j1 == nscale - 1 - k - l_orient) * (j2 == nscale - 1 - k)
912
+ )[0]
913
+ i1 = np.where((j1 == nscale - 1 - k - l_orient) * (j2 == nscale - k))[0]
914
+ i2 = np.where(
915
+ (j1 == nscale - 1 - k - l_orient) * (j2 == nscale + 1 - k)
916
+ )[0]
917
+ i3 = np.where(
918
+ (j1 == nscale - 1 - k - l_orient) * (j2 == nscale + 2 - k)
919
+ )[0]
920
+
680
921
  if constant:
681
- s2[:,i0]=s2[:,i1]
682
- s2l[:,i0]=s2l[:,i1]
922
+ s2[:, i0] = s2[:, i1]
923
+ s2l[:, i0] = s2l[:, i1]
683
924
  else:
684
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
685
- if len(idx[0])>0:
686
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
687
-
688
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
689
- if len(idx[0])>0:
690
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
691
-
692
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
693
- if len(idx[0])>0:
694
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
695
-
696
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
697
- if len(idx[0])>0:
698
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
699
-
925
+ idx = np.where(
926
+ (s2[:, i2] > 0) * (s2[:, i3] > 0) * (s2[:, i2] < s2th[:, i2])
927
+ )
928
+ if len(idx[0]) > 0:
929
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
930
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
931
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
932
+ )
933
+
934
+ idx = np.where(
935
+ (s2[:, i1] > 0) * (s2[:, i2] > 0) * (s2[:, i1] < s2th[:, i1])
936
+ )
937
+ if len(idx[0]) > 0:
938
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
939
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
940
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
941
+ )
942
+
943
+ idx = np.where(
944
+ (s2l[:, i2] > 0)
945
+ * (s2l[:, i3] > 0)
946
+ * (s2l[:, i2] < s2lth[:, i2])
947
+ )
948
+ if len(idx[0]) > 0:
949
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
950
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
951
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
952
+ )
953
+
954
+ idx = np.where(
955
+ (s2l[:, i1] > 0)
956
+ * (s2l[:, i2] > 0)
957
+ * (s2l[:, i1] < s2lth[:, i1])
958
+ )
959
+ if len(idx[0]) > 0:
960
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
961
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
962
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
963
+ )
964
+
700
965
  if extend:
701
966
  for k in range(nscale):
702
- for l in range(1,nscale):
703
- i0=np.where((j1==2*nscale-1-k)*(j2==2*nscale-1-k-l))[0]
704
- i1=np.where((j1==2*nscale-1-k)*(j2==2*nscale -k-l))[0]
705
- i2=np.where((j1==2*nscale-1-k)*(j2==2*nscale+1-k-l))[0]
706
- i3=np.where((j1==2*nscale-1-k)*(j2==2*nscale+2-k-l))[0]
967
+ for l_orient in range(1, nscale):
968
+ i0 = np.where(
969
+ (j1 == 2 * nscale - 1 - k)
970
+ * (j2 == 2 * nscale - 1 - k - l_orient)
971
+ )[0]
972
+ i1 = np.where(
973
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale - k - l_orient)
974
+ )[0]
975
+ i2 = np.where(
976
+ (j1 == 2 * nscale - 1 - k)
977
+ * (j2 == 2 * nscale + 1 - k - l_orient)
978
+ )[0]
979
+ i3 = np.where(
980
+ (j1 == 2 * nscale - 1 - k)
981
+ * (j2 == 2 * nscale + 2 - k - l_orient)
982
+ )[0]
707
983
  if constant:
708
- s2[:,i0]=s2[:,i1]
709
- s2l[:,i0]=s2l[:,i1]
984
+ s2[:, i0] = s2[:, i1]
985
+ s2l[:, i0] = s2l[:, i1]
710
986
  else:
711
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
712
- if len(idx[0])>0:
713
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
714
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
715
- if len(idx[0])>0:
716
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
717
-
718
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
719
- if len(idx[0])>0:
720
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
721
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
722
- if len(idx[0])>0:
723
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
724
-
725
- s1[np.isnan(s1)]=0.0
726
- p0[np.isnan(p0)]=0.0
727
- s2[np.isnan(s2)]=0.0
728
- s2l[np.isnan(s2l)]=0.0
729
-
730
- return scat1D(self.backend.constant(p0),self.S0,
731
- self.backend.constant(s1),
732
- self.backend.constant(s2),
733
- self.backend.constant(s2l),self.j1,self.j2,backend=self.backend)
987
+ idx = np.where(
988
+ (s2[:, i2] > 0)
989
+ * (s2[:, i3] > 0)
990
+ * (s2[:, i2] < s2th[:, i2])
991
+ )
992
+ if len(idx[0]) > 0:
993
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
994
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
995
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
996
+ )
997
+ idx = np.where(
998
+ (s2[:, i1] > 0)
999
+ * (s2[:, i2] > 0)
1000
+ * (s2[:, i1] < s2th[:, i1])
1001
+ )
1002
+ if len(idx[0]) > 0:
1003
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1004
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1005
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1006
+ )
1007
+
1008
+ idx = np.where(
1009
+ (s2l[:, i2] > 0)
1010
+ * (s2l[:, i3] > 0)
1011
+ * (s2l[:, i2] < s2lth[:, i2])
1012
+ )
1013
+ if len(idx[0]) > 0:
1014
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1015
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1016
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1017
+ )
1018
+ idx = np.where(
1019
+ (s2l[:, i1] > 0)
1020
+ * (s2l[:, i2] > 0)
1021
+ * (s2l[:, i1] < s2lth[:, i1])
1022
+ )
1023
+ if len(idx[0]) > 0:
1024
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1025
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1026
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1027
+ )
1028
+
1029
+ s1[np.isnan(s1)] = 0.0
1030
+ p0[np.isnan(p0)] = 0.0
1031
+ s2[np.isnan(s2)] = 0.0
1032
+ s2l[np.isnan(s2l)] = 0.0
1033
+
1034
+ return scat1D(
1035
+ self.backend.constant(p0),
1036
+ self.S0,
1037
+ self.backend.constant(s1),
1038
+ self.backend.constant(s2),
1039
+ self.backend.constant(s2l),
1040
+ self.j1,
1041
+ self.j2,
1042
+ backend=self.backend,
1043
+ )
734
1044
 
735
1045
  # ---------------------------------------------−---------
736
- def model(self,i__y,add=0,dx=3,dell=2,weigth=None,inverse=False):
1046
+ def model(self, i__y, add=0, dx=3, dell=2, weigth=None, inverse=False):
737
1047
 
738
- if i__y.shape[0]<dx+1:
739
- l__dx=i__y.shape[0]-1
1048
+ if i__y.shape[0] < dx + 1:
1049
+ l__dx = i__y.shape[0] - 1
740
1050
  else:
741
- l__dx=dx
1051
+ l__dx = dx
742
1052
 
743
- if i__y.shape[0]<dell:
744
- l__dell=0
1053
+ if i__y.shape[0] < dell:
1054
+ l__dell = 0
745
1055
  else:
746
- l__dell=dell
1056
+ l__dell = dell
747
1057
 
748
- if l__dx<2:
749
- res=np.zeros([i__y.shape[0]+add])
1058
+ if l__dx < 2:
1059
+ res = np.zeros([i__y.shape[0] + add])
750
1060
  if inverse:
751
- res[:-add]=i__y
1061
+ res[:-add] = i__y
752
1062
  else:
753
- res[add:]=i__y[0:]
1063
+ res[add:] = i__y[0:]
754
1064
  return res
755
1065
 
756
1066
  if weigth is None:
757
- w=2**(np.arange(l__dx))
1067
+ w = 2 ** (np.arange(l__dx))
758
1068
  else:
759
1069
  if not inverse:
760
- w=weigth[0:l__dx]
1070
+ w = weigth[0:l__dx]
761
1071
  else:
762
- w=weigth[-l__dx:]
1072
+ w = weigth[-l__dx:]
763
1073
 
764
- x=np.arange(l__dx)+1
1074
+ x = np.arange(l__dx) + 1
765
1075
  if not inverse:
766
- y=np.log(i__y[1:l__dx+1])
1076
+ y = np.log(i__y[1 : l__dx + 1])
767
1077
  else:
768
- y=np.log(i__y[-(l__dx+1):-1])
1078
+ y = np.log(i__y[-(l__dx + 1) : -1])
769
1079
 
770
- r=np.polyfit(x,y,1,w=w)
1080
+ r = np.polyfit(x, y, 1, w=w)
771
1081
 
772
1082
  if inverse:
773
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-1)+r[1])
774
- res[:-(l__dell+add)]=i__y[:-l__dell]
1083
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - 1) + r[1])
1084
+ res[: -(l__dell + add)] = i__y[:-l__dell]
775
1085
  else:
776
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-add)+r[1])
777
- res[l__dell+add:]=i__y[l__dell:]
1086
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - add) + r[1])
1087
+ res[l__dell + add :] = i__y[l__dell:]
778
1088
  return res
779
1089
 
780
- def findn(self,n):
781
- d=np.sqrt(1+8*n)
782
- return int((d-1)/2)
1090
+ def findn(self, n):
1091
+ d = np.sqrt(1 + 8 * n)
1092
+ return int((d - 1) / 2)
783
1093
 
784
- def findidx(self,s2):
785
- i1=np.zeros([s2.shape[1]],dtype='int')
786
- i2=np.zeros([s2.shape[1]],dtype='int')
787
- n=0
1094
+ def findidx(self, s2):
1095
+ i1 = np.zeros([s2.shape[1]], dtype="int")
1096
+ i2 = np.zeros([s2.shape[1]], dtype="int")
1097
+ n = 0
788
1098
  for k in range(self.findn(s2.shape[1])):
789
- i1[n:n+k+1]=np.arange(k+1)
790
- i2[n:n+k+1]=k
791
- n=n+k+1
792
- return i1,i2
793
-
794
- def extrapol_s2(self,add,lnorm=1):
795
- if lnorm==1:
796
- s2=self.S2.numpy()
797
- if lnorm==2:
798
- s2=self.S2L.numpy()
799
- i1,i2=self.findidx(s2)
800
-
801
- so2=np.zeros([s2.shape[0],(self.findn(s2.shape[1])+add)*(self.findn(s2.shape[1])+add+1)//2,s2.shape[2],s2.shape[3]])
802
- oi1,oi2=self.findidx(so2)
803
- for l in range(s2.shape[0]):
1099
+ i1[n : n + k + 1] = np.arange(k + 1)
1100
+ i2[n : n + k + 1] = k
1101
+ n = n + k + 1
1102
+ return i1, i2
1103
+
1104
+ def extrapol_s2(self, add, lnorm=1):
1105
+ if lnorm == 1:
1106
+ s2 = self.S2.numpy()
1107
+ if lnorm == 2:
1108
+ s2 = self.S2L.numpy()
1109
+ i1, i2 = self.findidx(s2)
1110
+
1111
+ so2 = np.zeros(
1112
+ [
1113
+ s2.shape[0],
1114
+ (self.findn(s2.shape[1]) + add)
1115
+ * (self.findn(s2.shape[1]) + add + 1)
1116
+ // 2,
1117
+ s2.shape[2],
1118
+ s2.shape[3],
1119
+ ]
1120
+ )
1121
+ oi1, oi2 = self.findidx(so2)
1122
+ for l_orient in range(s2.shape[0]):
804
1123
  for k in range(self.findn(s2.shape[1])):
805
1124
  for i in range(s2.shape[2]):
806
1125
  for j in range(s2.shape[3]):
807
- tmp=self.model(s2[l,i2==k,i,j],dx=4,dell=1,add=add,weigth=np.array([1,2,2,2]))
808
- tmp[np.isnan(tmp)]=0.0
809
- so2[l,oi2==k+add,i,j]=tmp
810
-
811
-
812
- for l in range(s2.shape[0]):
813
- for k in range(add+1,-1,-1):
814
- lidx=np.where(oi2-oi1==k)[0]
815
- lidx2=np.where(oi2-oi1==k+1)[0]
1126
+ tmp = self.model(
1127
+ s2[l_orient, i2 == k, i, j],
1128
+ dx=4,
1129
+ dell=1,
1130
+ add=add,
1131
+ weigth=np.array([1, 2, 2, 2]),
1132
+ )
1133
+ tmp[np.isnan(tmp)] = 0.0
1134
+ so2[l_orient, oi2 == k + add, i, j] = tmp
1135
+
1136
+ for l_orient in range(s2.shape[0]):
1137
+ for k in range(add + 1, -1, -1):
1138
+ lidx = np.where(oi2 - oi1 == k)[0]
1139
+ lidx2 = np.where(oi2 - oi1 == k + 1)[0]
816
1140
  for i in range(s2.shape[2]):
817
1141
  for j in range(s2.shape[3]):
818
- so2[l,lidx[0:add+2-k],i,j]=so2[l,lidx2[0:add+2-k],i,j]
1142
+ so2[l_orient, lidx[0 : add + 2 - k], i, j] = so2[
1143
+ l_orient, lidx2[0 : add + 2 - k], i, j
1144
+ ]
819
1145
 
820
- return(so2)
1146
+ return so2
821
1147
 
822
- def extrapol_s1(self,i_s1,add):
823
- s1=i_s1.numpy()
824
- so1=np.zeros([s1.shape[0],s1.shape[1]+add,s1.shape[2]])
1148
+ def extrapol_s1(self, i_s1, add):
1149
+ s1 = i_s1.numpy()
1150
+ so1 = np.zeros([s1.shape[0], s1.shape[1] + add, s1.shape[2]])
825
1151
  for k in range(s1.shape[0]):
826
1152
  for i in range(s1.shape[2]):
827
- so1[k,:,i]=self.model(s1[k,:,i],dx=4,dell=1,add=add)
828
- so1[k,np.isnan(so1[k,:,i]),i]=0.0
1153
+ so1[k, :, i] = self.model(s1[k, :, i], dx=4, dell=1, add=add)
1154
+ so1[k, np.isnan(so1[k, :, i]), i] = 0.0
829
1155
  return so1
830
1156
 
831
- def extrapol(self,add):
832
- return scat1D(self.extrapol_s1(self.P00,add), \
833
- self.S0, \
834
- self.extrapol_s1(self.S1,add), \
835
- self.extrapol_s2(add,lnorm=1), \
836
- self.extrapol_s2(add,lnorm=2),self.j1,self.j2,backend=self.backend)
837
-
838
-
839
-
840
-
841
-
1157
+ def extrapol(self, add):
1158
+ return scat1D(
1159
+ self.extrapol_s1(self.P00, add),
1160
+ self.S0,
1161
+ self.extrapol_s1(self.S1, add),
1162
+ self.extrapol_s2(add, lnorm=1),
1163
+ self.extrapol_s2(add, lnorm=2),
1164
+ self.j1,
1165
+ self.j2,
1166
+ backend=self.backend,
1167
+ )
1168
+
1169
+
842
1170
  class funct(FOC.FoCUS):
843
1171
 
844
- def fill(self,im,nullval=0):
845
- return self.fill_1d(im,nullval=nullval)
846
-
847
- def ud_grade(self,im,nout,axis=0):
848
- return self.ud_grade_1d(im,nout,axis=axis)
849
-
850
- def up_grade(self,im,nout,axis=0):
851
- return self.up_grade_1d(im,nout,axis=axis)
852
-
853
- def smooth(self,data,axis=0):
854
- return self.smooth_1d(data,axis=axis)
855
-
856
- def convol(self,data,axis=0):
857
- return self.convol_1d(data,axis=axis)
858
-
859
- def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,Add_R45=False,axis=0):
1172
+ def fill(self, im, nullval=0):
1173
+ return self.fill_1d(im, nullval=nullval)
1174
+
1175
+ def ud_grade(self, im, nout, axis=0):
1176
+ return self.ud_grade_1d(im, nout, axis=axis)
1177
+
1178
+ def up_grade(self, im, nout, axis=0):
1179
+ return self.up_grade_1d(im, nout, axis=axis)
1180
+
1181
+ def smooth(self, data, axis=0):
1182
+ return self.smooth_1d(data, axis=axis)
1183
+
1184
+ def convol(self, data, axis=0):
1185
+ return self.convol_1d(data, axis=axis)
1186
+
1187
+ def eval(
1188
+ self,
1189
+ image1,
1190
+ image2=None,
1191
+ mask=None,
1192
+ Auto=True,
1193
+ s0_off=1e-6,
1194
+ Add_R45=False,
1195
+ axis=0,
1196
+ ):
860
1197
  # Check input consistency
861
1198
  if mask is not None:
862
- if len(image1.shape)==1:
863
- if image1.shape[0]!=mask.shape[1]:
864
- print('The mask should have the same size than the input timeline to eval Scattering')
1199
+ if len(image1.shape) == 1:
1200
+ if image1.shape[0] != mask.shape[1]:
1201
+ print(
1202
+ "The mask should have the same size than the input timeline to eval Scattering"
1203
+ )
865
1204
  return None
866
1205
  else:
867
- if image1.shape[1]!=mask.shape[1]:
868
- print('The mask should have the same size than the input timeline to eval Scattering')
1206
+ if image1.shape[1] != mask.shape[1]:
1207
+ print(
1208
+ "The mask should have the same size than the input timeline to eval Scattering"
1209
+ )
869
1210
  return None
870
-
1211
+
871
1212
  ### AUTO OR CROSS
872
1213
  cross = False
873
1214
  if image2 is not None:
874
1215
  cross = True
875
- all_cross=not Auto
876
- else:
877
- all_cross=False
878
-
1216
+
879
1217
  # determine jmax and nside corresponding to the input map
880
1218
  im_shape = image1.shape
881
1219
 
882
- nside=im_shape[len(image1.shape)-1]
883
- npix=nside
884
-
885
- jmax=int(np.log(nside)/np.log(2)) #-self.OSTEP
1220
+ nside = im_shape[len(image1.shape) - 1]
1221
+ npix = nside
1222
+
1223
+ jmax = int(np.log(nside) / np.log(2)) # -self.OSTEP
886
1224
 
887
1225
  ### LOCAL VARIABLES (IMAGES and MASK)
888
1226
  # Check if image1 is [Npix] or [Nbatch,Npix]
889
- if len(image1.shape)==1:
1227
+ if len(image1.shape) == 1:
890
1228
  # image1 is [Nbatch, Npix]
891
- I1 = self.backend.bk_cast(self.backend.bk_expand_dims(image1,0)) # Local image1 [Nbatch, Npix]
1229
+ I1 = self.backend.bk_cast(
1230
+ self.backend.bk_expand_dims(image1, 0)
1231
+ ) # Local image1 [Nbatch, Npix]
892
1232
  if cross:
893
- I2 = self.backend.bk_cast(self.backend.bk_expand_dims(image2,0)) # Local image2 [Nbatch, Npix]
1233
+ I2 = self.backend.bk_cast(
1234
+ self.backend.bk_expand_dims(image2, 0)
1235
+ ) # Local image2 [Nbatch, Npix]
894
1236
  else:
895
- I1=self.backend.bk_cast(image1)
1237
+ I1 = self.backend.bk_cast(image1)
896
1238
  if cross:
897
- I2=self.backend.bk_cast(image2)
898
-
1239
+ I2 = self.backend.bk_cast(image2)
1240
+
899
1241
  # self.mask is [Nmask, Npix]
900
-
1242
+
901
1243
  if mask is None:
902
1244
  vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
903
1245
  else:
904
1246
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
905
1247
 
906
- if self.KERNELSZ>3:
1248
+ if self.KERNELSZ > 3:
907
1249
  # if the kernel size is bigger than 3 increase the binning before smoothing
908
- l_image1=self.up_grade_1d(I1,nside*2,axis=axis+1)
909
- vmask=self.up_grade_1d(vmask,nside*2,axis=1)
910
-
1250
+ l_image1 = self.up_grade_1d(I1, nside * 2, axis=axis + 1)
1251
+ vmask = self.up_grade_1d(vmask, nside * 2, axis=1)
1252
+
911
1253
  if cross:
912
- l_image2=self.up_grade_1d(I2,nside*2,axis=axis+1)
1254
+ l_image2 = self.up_grade_1d(I2, nside * 2, axis=axis + 1)
913
1255
  else:
914
- l_image1=I1
1256
+ l_image1 = I1
915
1257
  if cross:
916
- l_image2=I2
917
- if len(image1.shape)==1:
918
- s0 = self.backend.bk_reduce_sum(l_image1*vmask,axis=axis+1)
919
- if cross and Auto==False:
920
- s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(l_image2*vmask,axis=axis)])
1258
+ l_image2 = I2
1259
+ if len(image1.shape) == 1:
1260
+ s0 = self.backend.bk_reduce_sum(l_image1 * vmask, axis=axis + 1)
1261
+ if cross and not Auto:
1262
+ s0 = self.backend.bk_concat(
1263
+ [s0, self.backend.bk_reduce_sum(l_image2 * vmask, axis=axis)]
1264
+ )
921
1265
  else:
922
- lmask=self.backend.bk_expand_dims(vmask,0)
923
- lim=self.backend.bk_expand_dims(l_image1,1)
924
- s0 = self.backend.bk_reduce_sum(lim*lmask,axis=axis+2)
925
- if cross and Auto==False:
926
- lim=self.backend.bk_expand_dims(l_image2,1)
927
- s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(lim*lmask,axis=axis+2)])
928
-
929
-
930
-
931
- s1=None
932
- s2=None
933
- s2l=None
934
- p00=None
935
- s2j1=None
936
- s2j2=None
937
-
938
- l2_image=None
939
- l2_image_imag=None
1266
+ lmask = self.backend.bk_expand_dims(vmask, 0)
1267
+ lim = self.backend.bk_expand_dims(l_image1, 1)
1268
+ s0 = self.backend.bk_reduce_sum(lim * lmask, axis=axis + 2)
1269
+ if cross and not Auto:
1270
+ lim = self.backend.bk_expand_dims(l_image2, 1)
1271
+ s0 = self.backend.bk_concat(
1272
+ [s0, self.backend.bk_reduce_sum(lim * lmask, axis=axis + 2)]
1273
+ )
1274
+
1275
+ s1 = None
1276
+ s2 = None
1277
+ s2l = None
1278
+ p00 = None
1279
+ s2j1 = None
1280
+ s2j2 = None
1281
+
1282
+ l2_image = None
940
1283
 
941
1284
  for j1 in range(jmax):
942
- if j1<jmax-self.OSTEP: # stop to add scales
1285
+ if j1 < jmax - self.OSTEP: # stop to add scales
943
1286
  # Convol image along the axis defined by 'axis' using the wavelet defined at
944
1287
  # the foscat initialisation
945
- #c_image_real is [....,Npix_j1,....,Norient]
946
- c_image1=self.convol_1d(l_image1,axis=axis+1)
1288
+ # c_image_real is [....,Npix_j1,....,Norient]
1289
+ c_image1 = self.convol_1d(l_image1, axis=axis + 1)
947
1290
  if cross:
948
- c_image2=self.convol_1d(l_image2,axis=axis+1)
1291
+ c_image2 = self.convol_1d(l_image2, axis=axis + 1)
949
1292
  else:
950
- c_image2=c_image1
1293
+ c_image2 = c_image1
951
1294
 
952
1295
  # Compute (a+ib)*(a+ib)* the last c_image column is the real and imaginary part
953
- conj=c_image1*self.backend.bk_conjugate(c_image2)
954
-
1296
+ conj = c_image1 * self.backend.bk_conjugate(c_image2)
1297
+
955
1298
  if Auto:
956
- conj=self.backend.bk_real(conj)
1299
+ conj = self.backend.bk_real(conj)
957
1300
 
958
1301
  # Compute l_p00 [....,....,Nmask,1]
959
- l_p00 = self.backend.bk_expand_dims(self.backend.bk_reduce_sum(conj*vmask,axis=1),-1)
1302
+ l_p00 = self.backend.bk_expand_dims(
1303
+ self.backend.bk_reduce_sum(conj * vmask, axis=1), -1
1304
+ )
960
1305
 
961
- conj=self.backend.bk_L1(conj)
1306
+ conj = self.backend.bk_L1(conj)
962
1307
 
963
1308
  # Compute l_s1 [....,....,Nmask,1]
964
- l_s1 = self.backend.bk_expand_dims(self.backend.bk_reduce_sum(conj*vmask,axis=1),-1)
965
-
966
- # Concat S1,P00 [....,....,Nmask,j1]
1309
+ l_s1 = self.backend.bk_expand_dims(
1310
+ self.backend.bk_reduce_sum(conj * vmask, axis=1), -1
1311
+ )
1312
+
1313
+ # Concat S1,P00 [....,....,Nmask,j1]
967
1314
  if s1 is None:
968
- s1=l_s1
969
- p00=l_p00
1315
+ s1 = l_s1
1316
+ p00 = l_p00
970
1317
  else:
971
- s1=self.backend.bk_concat([s1,l_s1],axis=-1)
972
- p00=self.backend.bk_concat([p00,l_p00],axis=-1)
1318
+ s1 = self.backend.bk_concat([s1, l_s1], axis=-1)
1319
+ p00 = self.backend.bk_concat([p00, l_p00], axis=-1)
973
1320
 
974
1321
  # Concat l2_image [....,Npix_j1,....,j1]
975
1322
  if l2_image is None:
976
- l2_image=self.backend.bk_expand_dims(conj,axis=1)
1323
+ l2_image = self.backend.bk_expand_dims(conj, axis=1)
977
1324
  else:
978
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=1),l2_image],axis=1)
1325
+ l2_image = self.backend.bk_concat(
1326
+ [self.backend.bk_expand_dims(conj, axis=1), l2_image], axis=1
1327
+ )
979
1328
 
980
1329
  # Convol l2_image [....,Npix_j1,....,j1,Norient,Norient]
981
- c2_image=self.convol_1d(self.backend.bk_relu(l2_image),axis=axis+2)
1330
+ c2_image = self.convol_1d(self.backend.bk_relu(l2_image), axis=axis + 2)
982
1331
 
983
- conj2p=c2_image*self.backend.bk_conjugate(c2_image)
984
- conj2pl1=self.backend.bk_L1(conj2p)
1332
+ conj2p = c2_image * self.backend.bk_conjugate(c2_image)
1333
+ conj2pl1 = self.backend.bk_L1(conj2p)
985
1334
 
986
1335
  if Auto:
987
- conj2p=self.backend.bk_real(conj2p)
988
- conj2pl1=self.backend.bk_real(conj2pl1)
1336
+ conj2p = self.backend.bk_real(conj2p)
1337
+ conj2pl1 = self.backend.bk_real(conj2pl1)
989
1338
 
990
- c2_image=self.convol_1d(self.backend.bk_relu(-l2_image),axis=axis+2)
1339
+ c2_image = self.convol_1d(self.backend.bk_relu(-l2_image), axis=axis + 2)
991
1340
 
992
- conj2m=c2_image*self.backend.bk_conjugate(c2_image)
993
- conj2ml1=self.backend.bk_L1(conj2m)
1341
+ conj2m = c2_image * self.backend.bk_conjugate(c2_image)
1342
+ conj2ml1 = self.backend.bk_L1(conj2m)
994
1343
 
995
1344
  if Auto:
996
- conj2m=self.backend.bk_real(conj2m)
997
- conj2ml1=self.backend.bk_real(conj2ml1)
998
-
1345
+ conj2m = self.backend.bk_real(conj2m)
1346
+ conj2ml1 = self.backend.bk_real(conj2ml1)
1347
+
999
1348
  # Convol l_s2 [....,....,Nmask,j1]
1000
- l_s2 = self.backend.bk_reduce_sum((conj2p-conj2m)*self.backend.bk_expand_dims(vmask,1),axis=axis+2)
1001
- l_s2l1 = self.backend.bk_reduce_sum((conj2pl1-conj2ml1)*self.backend.bk_expand_dims(vmask,1),axis=axis+2)
1349
+ l_s2 = self.backend.bk_reduce_sum(
1350
+ (conj2p - conj2m) * self.backend.bk_expand_dims(vmask, 1), axis=axis + 2
1351
+ )
1352
+ l_s2l1 = self.backend.bk_reduce_sum(
1353
+ (conj2pl1 - conj2ml1) * self.backend.bk_expand_dims(vmask, 1),
1354
+ axis=axis + 2,
1355
+ )
1002
1356
 
1003
1357
  # Concat l_s2 [....,....,Nmask,j1*(j1+1)/2,Norient,Norient]
1004
1358
  if s2 is None:
1005
- s2l=l_s2
1006
- s2=l_s2l1
1007
- s2j1=np.arange(l_s2.shape[axis],dtype='int')
1008
- s2j2=j1*np.ones(l_s2.shape[axis],dtype='int')
1359
+ s2l = l_s2
1360
+ s2 = l_s2l1
1361
+ s2j1 = np.arange(l_s2.shape[axis], dtype="int")
1362
+ s2j2 = j1 * np.ones(l_s2.shape[axis], dtype="int")
1009
1363
  else:
1010
- s2=self.backend.bk_concat([s2,l_s2l1],axis=-1)
1011
- s2l=self.backend.bk_concat([s2l,l_s2],axis=-1)
1012
- s2j1=np.concatenate([s2j1,np.arange(l_s2.shape[axis+1],dtype='int')],0)
1013
- s2j2=np.concatenate([s2j2,j1*np.ones(l_s2.shape[axis+1],dtype='int')],0)
1014
-
1015
- if j1!=jmax-1:
1016
- # Rescale vmask [Nmask,Npix_j1//4]
1017
- vmask = self.smooth_1d(vmask,axis=1)
1018
- vmask = self.ud_grade_1d(vmask,vmask.shape[1]//2,axis=1)
1364
+ s2 = self.backend.bk_concat([s2, l_s2l1], axis=-1)
1365
+ s2l = self.backend.bk_concat([s2l, l_s2], axis=-1)
1366
+ s2j1 = np.concatenate(
1367
+ [s2j1, np.arange(l_s2.shape[axis + 1], dtype="int")], 0
1368
+ )
1369
+ s2j2 = np.concatenate(
1370
+ [s2j2, j1 * np.ones(l_s2.shape[axis + 1], dtype="int")], 0
1371
+ )
1372
+
1373
+ if j1 != jmax - 1:
1374
+ # Rescale vmask [Nmask,Npix_j1//4]
1375
+ vmask = self.smooth_1d(vmask, axis=1)
1376
+ vmask = self.ud_grade_1d(vmask, vmask.shape[1] // 2, axis=1)
1019
1377
  if self.mask_thres is not None:
1020
- vmask = self.backend.bk_threshold(vmask,self.mask_thres)
1021
-
1022
- # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1023
- l2_image = self.smooth_1d(l2_image,axis=axis+2)
1024
- l2_image = self.ud_grade_1d(l2_image,l2_image.shape[axis+2]//2,axis=axis+2)
1025
-
1026
- # Rescale l_image [....,Npix_j1//4,....]
1027
- l_image1 = self.smooth_1d(l_image1,axis=axis+1)
1028
- l_image1 = self.ud_grade_1d(l_image1,l_image1.shape[axis+1]//2,axis=axis+1)
1378
+ vmask = self.backend.bk_threshold(vmask, self.mask_thres)
1379
+
1380
+ # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1381
+ l2_image = self.smooth_1d(l2_image, axis=axis + 2)
1382
+ l2_image = self.ud_grade_1d(
1383
+ l2_image, l2_image.shape[axis + 2] // 2, axis=axis + 2
1384
+ )
1385
+
1386
+ # Rescale l_image [....,Npix_j1//4,....]
1387
+ l_image1 = self.smooth_1d(l_image1, axis=axis + 1)
1388
+ l_image1 = self.ud_grade_1d(
1389
+ l_image1, l_image1.shape[axis + 1] // 2, axis=axis + 1
1390
+ )
1029
1391
  if cross:
1030
- l_image2 = self.smooth_1d(l_image2,axis=axis+2)
1031
- l_image2 = self.ud_grade_1d(l_image2,l_image2.shape[axis+2]//2,axis=axis+2)
1392
+ l_image2 = self.smooth_1d(l_image2, axis=axis + 2)
1393
+ l_image2 = self.ud_grade_1d(
1394
+ l_image2, l_image2.shape[axis + 2] // 2, axis=axis + 2
1395
+ )
1032
1396
 
1033
- return(scat1D(p00,s0,s1,s2,s2l,s2j1,s2j2,cross=cross,backend=self.backend))
1397
+ return scat1D(
1398
+ p00, s0, s1, s2, s2l, s2j1, s2j2, cross=cross, backend=self.backend
1399
+ )
1034
1400
 
1035
- def square(self,x):
1401
+ def square(self, x):
1036
1402
  # the abs make the complex value usable for reduce_sum or mean
1037
- return scat1D(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1038
- self.backend.bk_square(self.backend.bk_abs(x.S0)),
1039
- self.backend.bk_square(self.backend.bk_abs(x.S1)),
1040
- self.backend.bk_square(self.backend.bk_abs(x.S2)),
1041
- self.backend.bk_square(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1042
-
1043
- def sqrt(self,x):
1403
+ return scat1D(
1404
+ self.backend.bk_square(self.backend.bk_abs(x.P00)),
1405
+ self.backend.bk_square(self.backend.bk_abs(x.S0)),
1406
+ self.backend.bk_square(self.backend.bk_abs(x.S1)),
1407
+ self.backend.bk_square(self.backend.bk_abs(x.S2)),
1408
+ self.backend.bk_square(self.backend.bk_abs(x.S2L)),
1409
+ x.j1,
1410
+ x.j2,
1411
+ backend=self.backend,
1412
+ )
1413
+
1414
+ def sqrt(self, x):
1044
1415
  # the abs make the complex value usable for reduce_sum or mean
1045
- return scat1D(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1046
- self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1047
- self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1048
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1049
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1050
-
1051
- def reduce_mean(self,x,axis=None):
1416
+ return scat1D(
1417
+ self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1418
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1419
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1420
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1421
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),
1422
+ x.j1,
1423
+ x.j2,
1424
+ backend=self.backend,
1425
+ )
1426
+
1427
+ def reduce_mean(self, x, axis=None):
1052
1428
  if axis is None:
1053
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))+ \
1054
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))+ \
1055
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))+ \
1056
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))+ \
1057
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1058
-
1059
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1060
- np.array(list(x.S0.shape)).prod()+ \
1061
- np.array(list(x.S1.shape)).prod()+ \
1062
- np.array(list(x.S2.shape)).prod()
1063
-
1064
- return tmp/ntmp
1429
+ tmp = (
1430
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))
1431
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))
1432
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))
1433
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))
1434
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1435
+ )
1436
+
1437
+ ntmp = (
1438
+ np.array(list(x.P00.shape)).prod()
1439
+ + np.array(list(x.S0.shape)).prod()
1440
+ + np.array(list(x.S1.shape)).prod()
1441
+ + np.array(list(x.S2.shape)).prod()
1442
+ )
1443
+
1444
+ return tmp / ntmp
1065
1445
  else:
1066
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00,axis=axis))+ \
1067
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0,axis=axis))+ \
1068
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1,axis=axis))+ \
1069
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2,axis=axis))+ \
1070
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L,axis=axis))
1071
-
1072
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1073
- np.array(list(x.S0.shape)).prod()+ \
1074
- np.array(list(x.S1.shape)).prod()+ \
1075
- np.array(list(x.S2.shape)).prod()+ \
1076
- np.array(list(x.S2L.shape)).prod()
1077
-
1078
- return tmp/ntmp
1079
-
1080
- def reduce_sum(self,x,axis=None):
1446
+ tmp = (
1447
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00, axis=axis))
1448
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0, axis=axis))
1449
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1, axis=axis))
1450
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2, axis=axis))
1451
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L, axis=axis))
1452
+ )
1453
+
1454
+ ntmp = (
1455
+ np.array(list(x.P00.shape)).prod()
1456
+ + np.array(list(x.S0.shape)).prod()
1457
+ + np.array(list(x.S1.shape)).prod()
1458
+ + np.array(list(x.S2.shape)).prod()
1459
+ + np.array(list(x.S2L.shape)).prod()
1460
+ )
1461
+
1462
+ return tmp / ntmp
1463
+
1464
+ def reduce_sum(self, x, axis=None):
1081
1465
  if axis is None:
1082
- return self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))+ \
1083
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))+ \
1084
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))+ \
1085
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))+ \
1086
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1466
+ return (
1467
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))
1468
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
1469
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
1470
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
1471
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1472
+ )
1087
1473
  else:
1088
- return scat1D(self.backend.bk_reduce_sum(x.P00,axis=axis),
1089
- self.backend.bk_reduce_sum(x.S0,axis=axis),
1090
- self.backend.bk_reduce_sum(x.S1,axis=axis),
1091
- self.backend.bk_reduce_sum(x.S2,axis=axis),
1092
- self.backend.bk_reduce_sum(x.S2L,axis=axis),x.j1,x.j2,backend=self.backend)
1093
-
1094
- def ldiff(self,sig,x):
1095
- return scat1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1096
- x.domult(sig.S0,x.S0)*x.domult(sig.S0,x.S0),
1097
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1098
- x.domult(sig.S2,x.S2)*x.domult(sig.S2,x.S2),
1099
- x.domult(sig.S2L,x.S2L)*x.domult(sig.S2L,x.S2L),x.j1,x.j2,backend=self.backend)
1100
-
1101
- def log(self,x):
1102
- return scat1D(self.backend.bk_log(x.P00),
1103
- self.backend.bk_log(x.S0),
1104
- self.backend.bk_log(x.S1),
1105
- self.backend.bk_log(x.S2),
1106
- self.backend.bk_log(x.S2L),x.j1,x.j2,backend=self.backend)
1107
- def abs(self,x):
1108
- return scat1D(self.backend.bk_abs(x.P00),
1109
- self.backend.bk_abs(x.S0),
1110
- self.backend.bk_abs(x.S1),
1111
- self.backend.bk_abs(x.S2),
1112
- self.backend.bk_abs(x.S2L),x.j1,x.j2,backend=self.backend)
1113
- def inv(self,x):
1114
- return scat1D(1/(x.P00),1/(x.S0),1/(x.S1),1/(x.S2),1/(x.S2L),x.j1,x.j2,backend=self.backend)
1474
+ return scat1D(
1475
+ self.backend.bk_reduce_sum(x.P00, axis=axis),
1476
+ self.backend.bk_reduce_sum(x.S0, axis=axis),
1477
+ self.backend.bk_reduce_sum(x.S1, axis=axis),
1478
+ self.backend.bk_reduce_sum(x.S2, axis=axis),
1479
+ self.backend.bk_reduce_sum(x.S2L, axis=axis),
1480
+ x.j1,
1481
+ x.j2,
1482
+ backend=self.backend,
1483
+ )
1484
+
1485
+ def ldiff(self, sig, x):
1486
+ return scat1D(
1487
+ x.domult(sig.P00, x.P00) * x.domult(sig.P00, x.P00),
1488
+ x.domult(sig.S0, x.S0) * x.domult(sig.S0, x.S0),
1489
+ x.domult(sig.S1, x.S1) * x.domult(sig.S1, x.S1),
1490
+ x.domult(sig.S2, x.S2) * x.domult(sig.S2, x.S2),
1491
+ x.domult(sig.S2L, x.S2L) * x.domult(sig.S2L, x.S2L),
1492
+ x.j1,
1493
+ x.j2,
1494
+ backend=self.backend,
1495
+ )
1496
+
1497
+ def log(self, x):
1498
+ return scat1D(
1499
+ self.backend.bk_log(x.P00),
1500
+ self.backend.bk_log(x.S0),
1501
+ self.backend.bk_log(x.S1),
1502
+ self.backend.bk_log(x.S2),
1503
+ self.backend.bk_log(x.S2L),
1504
+ x.j1,
1505
+ x.j2,
1506
+ backend=self.backend,
1507
+ )
1508
+
1509
+ def abs(self, x):
1510
+ return scat1D(
1511
+ self.backend.bk_abs(x.P00),
1512
+ self.backend.bk_abs(x.S0),
1513
+ self.backend.bk_abs(x.S1),
1514
+ self.backend.bk_abs(x.S2),
1515
+ self.backend.bk_abs(x.S2L),
1516
+ x.j1,
1517
+ x.j2,
1518
+ backend=self.backend,
1519
+ )
1520
+
1521
+ def inv(self, x):
1522
+ return scat1D(
1523
+ 1 / (x.P00),
1524
+ 1 / (x.S0),
1525
+ 1 / (x.S1),
1526
+ 1 / (x.S2),
1527
+ 1 / (x.S2L),
1528
+ x.j1,
1529
+ x.j2,
1530
+ backend=self.backend,
1531
+ )
1115
1532
 
1116
1533
  def one(self):
1117
- return scat1D(1.0,1.0,1.0,1.0,1.0,[0],[0],backend=self.backend)
1534
+ return scat1D(1.0, 1.0, 1.0, 1.0, 1.0, [0], [0], backend=self.backend)
1118
1535
 
1119
- def eval_comp_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1536
+ def eval_comp_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1120
1537
 
1121
- res=self.eval_fast(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1122
- return res.P00,res.S0,res.S1,res.S2,res.S2L,res.j1,res.j2
1538
+ res = self.eval_fast(image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off)
1539
+ return res.P00, res.S0, res.S1, res.S2, res.S2L, res.j1, res.j2
1123
1540
 
1124
1541
  @tf_function
1125
- def eval_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1126
- p0,s0,s1,s2,s2l,j1,j2=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1127
- return scat1D(p0,s0,s1,s2,s2l,j1,j2,backend=self.backend)
1128
-
1129
-
1130
-
1542
+ def eval_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1543
+ p0, s0, s1, s2, s2l, j1, j2 = self.eval_comp_fast(
1544
+ image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off
1545
+ )
1546
+ return scat1D(p0, s0, s1, s2, s2l, j1, j2, backend=self.backend)