foscat 3.1.6__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat1D.py CHANGED
@@ -1,453 +1,608 @@
1
- import foscat.FoCUS as FOC
2
- import numpy as np
3
1
  import pickle
4
- import foscat.backend as bk
5
2
  import sys
6
3
 
4
+ import numpy as np
5
+
6
+ import foscat.backend as bk
7
+ import foscat.FoCUS as FOC
8
+
7
9
  # Vérifier si TensorFlow est importé et défini
8
- tf_defined = 'tensorflow' in sys.modules
10
+ tf_defined = "tensorflow" in sys.modules
9
11
 
10
12
  if tf_defined:
11
13
  import tensorflow as tf
12
- tf_function = tf.function # Facultatif : si vous voulez utiliser TensorFlow dans ce script
14
+
15
+ tf_function = (
16
+ tf.function
17
+ ) # Facultatif : si vous voulez utiliser TensorFlow dans ce script
13
18
  else:
19
+
14
20
  def tf_function(func):
15
21
  return func
16
-
22
+
23
+
17
24
  def read(filename):
18
- thescat=scat1D(1,1,1,1,1,[0],[0])
25
+ thescat = scat1D(1, 1, 1, 1, 1, [0], [0])
19
26
  return thescat.read(filename)
20
-
27
+
28
+
21
29
  class scat1D:
22
- def __init__(self,p00,s0,s1,s2,s2l,j1,j2,cross=False,backend=None):
23
- self.P00=p00
24
- self.S0=s0
25
- self.S1=s1
26
- self.S2=s2
27
- self.S2L=s2l
28
- self.j1=j1
29
- self.j2=j2
30
- self.cross=cross
31
- self.backend=backend
30
+ def __init__(self, p00, s0, s1, s2, s2l, j1, j2, cross=False, backend=None):
31
+ self.P00 = p00
32
+ self.S0 = s0
33
+ self.S1 = s1
34
+ self.S2 = s2
35
+ self.S2L = s2l
36
+ self.j1 = j1
37
+ self.j2 = j2
38
+ self.cross = cross
39
+ self.backend = backend
32
40
 
33
41
  # ---------------------------------------------−---------
34
- def build_flat(self,table):
35
- shape=table.shape
36
- ndata=1
37
- for k in range(1,len(table.shape)):
38
- ndata=ndata*table.shape[k]
39
- return self.backend.bk_reshape(table,[table.shape[0],ndata])
40
-
42
+ def build_flat(self, table):
43
+ shape = table.shape
44
+ ndata = 1
45
+ for k in range(1, len(shape)):
46
+ ndata = ndata * shape[k]
47
+ return self.backend.bk_reshape(table, [shape[0], ndata])
48
+
41
49
  # ---------------------------------------------−---------
42
- def flatten(self,S2L=True,P00=True):
43
- tmp=[self.build_flat(self.S0),self.build_flat(self.S1)]
44
-
50
+ def flatten(self, S2L=True, P00=True):
51
+ tmp = [self.build_flat(self.S0), self.build_flat(self.S1)]
52
+
45
53
  if P00:
46
- tmp=tmp+[self.build_flat(self.P00)]
54
+ tmp = tmp + [self.build_flat(self.P00)]
55
+
56
+ tmp = tmp + [self.build_flat(self.S2)]
47
57
 
48
- tmp=tmp+[self.build_flat(self.S2)]
49
-
50
58
  if S2L:
51
- tmp=tmp+[self.build_flat(self.S2L)]
52
-
53
- if isinstance(self.P00,np.ndarray):
54
- return np.concatenate(tmp,1)
59
+ tmp = tmp + [self.build_flat(self.S2L)]
60
+
61
+ if isinstance(self.P00, np.ndarray):
62
+ return np.concatenate(tmp, 1)
55
63
  else:
56
- return self.backend.bk_concat(tmp,1)
64
+ return self.backend.bk_concat(tmp, 1)
65
+
57
66
  # ---------------------------------------------−---------
58
- def flatten_name(self,S2L=True,P00=True):
59
-
60
- tmp=['S0']
67
+ def flatten_name(self, S2L=True, P00=True):
68
+
69
+ tmp = ["S0"]
70
+
71
+ tmp = tmp + ["S1_%d" % (k) for k in range(self.S1.shape[-1])]
61
72
 
62
-
63
- tmp=tmp+['S1_%d'%(k) for k in range(self.S1.shape[-1])]
64
-
65
73
  if P00:
66
- tmp=tmp+['P00_%d'%(k) for k in range(self.P00.shape[-1])]
74
+ tmp = tmp + ["P00_%d" % (k) for k in range(self.P00.shape[-1])]
75
+
76
+ j1, j2 = self.get_j_idx()
77
+
78
+ tmp = tmp + ["S2_%d-%d" % (j1[k], j2[k]) for k in range(self.S2.shape[-1])]
67
79
 
68
- j1,j2=self.get_j_idx()
69
-
70
- tmp=tmp+['S2_%d-%d'%(j1[k],j2[k]) for k in range(self.S2.shape[-1])]
71
-
72
80
  if S2L:
73
- tmp=tmp+['S2L_%d-%d'%(j1[k],j2[k]) for k in range(self.S2.shape[-1])]
74
-
81
+ tmp = tmp + ["S2L_%d-%d" % (j1[k], j2[k]) for k in range(self.S2.shape[-1])]
82
+
75
83
  return tmp
76
-
84
+
77
85
  def get_j_idx(self):
78
- return self.j1,self.j2
79
-
86
+ return self.j1, self.j2
87
+
80
88
  def get_S0(self):
81
- return(self.S0)
89
+ return self.S0
82
90
 
83
91
  def get_S1(self):
84
- return(self.S1)
85
-
92
+ return self.S1
93
+
86
94
  def get_S2(self):
87
- return(self.S2)
95
+ return self.S2
88
96
 
89
97
  def get_S2L(self):
90
- return(self.S2L)
98
+ return self.S2L
91
99
 
92
100
  def get_P00(self):
93
- return(self.P00)
101
+ return self.P00
94
102
 
95
103
  def reset_P00(self):
96
- self.P00=0*self.P00
104
+ self.P00 = 0 * self.P00
97
105
 
98
106
  def constant(self):
99
- return scat1D(self.backend.constant(self.P00 ), \
100
- self.backend.constant(self.S0 ), \
101
- self.backend.constant(self.S1 ), \
102
- self.backend.constant(self.S2 ), \
103
- self.backend.constant(self.S2L ), \
104
- self.j1 , \
105
- self.j2 ,backend=self.backend)
106
-
107
- def domult(self,x,y):
108
- if x.dtype==y.dtype:
109
- return x*y
110
-
107
+ return scat1D(
108
+ self.backend.constant(self.P00),
109
+ self.backend.constant(self.S0),
110
+ self.backend.constant(self.S1),
111
+ self.backend.constant(self.S2),
112
+ self.backend.constant(self.S2L),
113
+ self.j1,
114
+ self.j2,
115
+ backend=self.backend,
116
+ )
117
+
118
+ def domult(self, x, y):
119
+ if x.dtype == y.dtype:
120
+ return x * y
121
+
111
122
  if self.backend.bk_is_complex(x):
112
-
113
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
123
+
124
+ return self.backend.bk_complex(
125
+ self.backend.bk_real(x) * y, self.backend.bk_imag(x) * y
126
+ )
114
127
  else:
115
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
116
-
117
- def dodiv(self,x,y):
118
- if x.dtype==y.dtype:
119
- return x/y
128
+ return self.backend.bk_complex(
129
+ self.backend.bk_real(y) * x, self.backend.bk_imag(y) * x
130
+ )
131
+
132
+ def dodiv(self, x, y):
133
+ if x.dtype == y.dtype:
134
+ return x / y
120
135
  if self.backend.bk_is_complex(x):
121
-
122
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
136
+
137
+ return self.backend.bk_complex(
138
+ self.backend.bk_real(x) / y, self.backend.bk_imag(x) / y
139
+ )
123
140
  else:
124
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
125
-
126
- def domin(self,x,y):
127
- if x.dtype==y.dtype:
128
- return x-y
129
-
141
+ return self.backend.bk_complex(
142
+ x / self.backend.bk_real(y), x / self.backend.bk_imag(y)
143
+ )
144
+
145
+ def domin(self, x, y):
146
+ if x.dtype == y.dtype:
147
+ return x - y
148
+
130
149
  if self.backend.bk_is_complex(x):
131
-
132
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
150
+
151
+ return self.backend.bk_complex(
152
+ self.backend.bk_real(x) - y, self.backend.bk_imag(x) - y
153
+ )
133
154
  else:
134
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
135
-
136
- def doadd(self,x,y):
137
- if x.dtype==y.dtype:
138
- return x+y
139
-
155
+ return self.backend.bk_complex(
156
+ x - self.backend.bk_real(y), x - self.backend.bk_imag(y)
157
+ )
158
+
159
+ def doadd(self, x, y):
160
+ if x.dtype == y.dtype:
161
+ return x + y
162
+
140
163
  if self.backend.bk_is_complex(x):
141
-
142
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
164
+
165
+ return self.backend.bk_complex(
166
+ self.backend.bk_real(x) + y, self.backend.bk_imag(x) + y
167
+ )
143
168
  else:
144
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
145
-
146
- def relu(self):
147
-
148
- return scat1D(self.backend.bk_relu(self.P00), \
149
- self.backend.bk_relu(self.S0), \
150
- self.backend.bk_relu(self.S1), \
151
- self.backend.bk_relu(self.S2), \
152
- self.backend.bk_relu(self.S2L), \
153
- self.j1,self.j2,backend=self.backend)
154
-
155
- def __add__(self,other):
156
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
157
- isinstance(other, bool) or isinstance(other, scat1D)
158
-
169
+ return self.backend.bk_complex(
170
+ x + self.backend.bk_real(y), x + self.backend.bk_imag(y)
171
+ )
172
+
173
+ def __add__(self, other):
174
+ assert (
175
+ isinstance(other, float)
176
+ or isinstance(other, np.float32)
177
+ or isinstance(other, int)
178
+ or isinstance(other, bool)
179
+ or isinstance(other, scat1D)
180
+ )
181
+
159
182
  if isinstance(other, scat1D):
160
- return scat1D(self.doadd(self.P00,other.P00), \
161
- self.doadd(self.S0, other.S0), \
162
- self.doadd(self.S1, other.S1), \
163
- self.doadd(self.S2, other.S2), \
164
- self.doadd(self.S2L, other.S2L), \
165
- self.j1,self.j2,backend=self.backend)
183
+ return scat1D(
184
+ self.doadd(self.P00, other.P00),
185
+ self.doadd(self.S0, other.S0),
186
+ self.doadd(self.S1, other.S1),
187
+ self.doadd(self.S2, other.S2),
188
+ self.doadd(self.S2L, other.S2L),
189
+ self.j1,
190
+ self.j2,
191
+ backend=self.backend,
192
+ )
166
193
  else:
167
- return scat1D((self.P00+ other), \
168
- (self.S0+ other), \
169
- (self.S1+ other), \
170
- (self.S2+ other), \
171
- (self.S2L+ other), \
172
- self.j1,self.j2,backend=self.backend)
173
-
174
- def toreal(self,value):
194
+ return scat1D(
195
+ (self.P00 + other),
196
+ (self.S0 + other),
197
+ (self.S1 + other),
198
+ (self.S2 + other),
199
+ (self.S2L + other),
200
+ self.j1,
201
+ self.j2,
202
+ backend=self.backend,
203
+ )
204
+
205
+ def toreal(self, value):
175
206
  if value is None:
176
207
  return None
177
-
208
+
178
209
  return self.backend.bk_real(value)
179
210
 
180
- def addcomplex(self,value,amp):
211
+ def addcomplex(self, value, amp):
181
212
  if value is None:
182
213
  return None
183
-
184
- return self.backend.bk_complex(value,amp*value)
185
-
186
- def add_complex(self,amp):
187
- return scat1D(self.addcomplex(self.P00,amp), \
188
- self.addcomplex(self.S0,amp), \
189
- self.addcomplex(self.S1,amp), \
190
- self.addcomplex(self.S2,amp), \
191
- self.addcomplex(self.S2L,amp), \
192
- self.j1,self.j2,backend=self.backend)
193
-
214
+
215
+ return self.backend.bk_complex(value, amp * value)
216
+
217
+ def add_complex(self, amp):
218
+ return scat1D(
219
+ self.addcomplex(self.P00, amp),
220
+ self.addcomplex(self.S0, amp),
221
+ self.addcomplex(self.S1, amp),
222
+ self.addcomplex(self.S2, amp),
223
+ self.addcomplex(self.S2L, amp),
224
+ self.j1,
225
+ self.j2,
226
+ backend=self.backend,
227
+ )
228
+
194
229
  def real(self):
195
- return scat1D(self.toreal(self.P00), \
196
- self.toreal(self.S0), \
197
- self.toreal(self.S1), \
198
- self.toreal(self.S2), \
199
- self.toreal(self.S2L), \
200
- self.j1,self.j2,backend=self.backend)
201
-
202
- def __radd__(self,other):
230
+ return scat1D(
231
+ self.toreal(self.P00),
232
+ self.toreal(self.S0),
233
+ self.toreal(self.S1),
234
+ self.toreal(self.S2),
235
+ self.toreal(self.S2L),
236
+ self.j1,
237
+ self.j2,
238
+ backend=self.backend,
239
+ )
240
+
241
+ def __radd__(self, other):
203
242
  return self.__add__(other)
204
243
 
205
- def __truediv__(self,other):
206
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
207
- isinstance(other, bool) or isinstance(other, scat1D)
208
-
244
+ def __truediv__(self, other):
245
+ assert (
246
+ isinstance(other, float)
247
+ or isinstance(other, np.float32)
248
+ or isinstance(other, int)
249
+ or isinstance(other, bool)
250
+ or isinstance(other, scat1D)
251
+ )
252
+
209
253
  if isinstance(other, scat1D):
210
- return scat1D(self.dodiv(self.P00, other.P00), \
211
- self.dodiv(self.S0, other.S0), \
212
- self.dodiv(self.S1, other.S1), \
213
- self.dodiv(self.S2, other.S2), \
214
- self.dodiv(self.S2L, other.S2L), \
215
- self.j1,self.j2,backend=self.backend)
254
+ return scat1D(
255
+ self.dodiv(self.P00, other.P00),
256
+ self.dodiv(self.S0, other.S0),
257
+ self.dodiv(self.S1, other.S1),
258
+ self.dodiv(self.S2, other.S2),
259
+ self.dodiv(self.S2L, other.S2L),
260
+ self.j1,
261
+ self.j2,
262
+ backend=self.backend,
263
+ )
216
264
  else:
217
- return scat1D((self.P00/ other), \
218
- (self.S0/ other), \
219
- (self.S1/ other), \
220
- (self.S2/ other), \
221
- (self.S2L/ other), \
222
- self.j1,self.j2,backend=self.backend)
223
-
224
-
225
- def __rtruediv__(self,other):
226
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
227
- isinstance(other, bool) or isinstance(other, scat1D)
228
-
265
+ return scat1D(
266
+ (self.P00 / other),
267
+ (self.S0 / other),
268
+ (self.S1 / other),
269
+ (self.S2 / other),
270
+ (self.S2L / other),
271
+ self.j1,
272
+ self.j2,
273
+ backend=self.backend,
274
+ )
275
+
276
+ def __rtruediv__(self, other):
277
+ assert (
278
+ isinstance(other, float)
279
+ or isinstance(other, np.float32)
280
+ or isinstance(other, int)
281
+ or isinstance(other, bool)
282
+ or isinstance(other, scat1D)
283
+ )
284
+
229
285
  if isinstance(other, scat1D):
230
- return scat1D(self.dodiv(other.P00, self.P00), \
231
- self.dodiv(other.S0 , self.S0), \
232
- self.dodiv(other.S1 , self.S1), \
233
- self.dodiv(other.S2 , self.S2), \
234
- self.dodiv(other.S2L , self.S2L), \
235
- self.j1,self.j2,backend=self.backend)
286
+ return scat1D(
287
+ self.dodiv(other.P00, self.P00),
288
+ self.dodiv(other.S0, self.S0),
289
+ self.dodiv(other.S1, self.S1),
290
+ self.dodiv(other.S2, self.S2),
291
+ self.dodiv(other.S2L, self.S2L),
292
+ self.j1,
293
+ self.j2,
294
+ backend=self.backend,
295
+ )
236
296
  else:
237
- return scat1D((other/ self.P00), \
238
- (other / self.S0), \
239
- (other / self.S1), \
240
- (other / self.S2), \
241
- (other / self.S2L), \
242
- self.j1,self.j2,backend=self.backend)
243
-
244
- def __sub__(self,other):
245
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
246
- isinstance(other, bool) or isinstance(other, scat1D)
247
-
297
+ return scat1D(
298
+ (other / self.P00),
299
+ (other / self.S0),
300
+ (other / self.S1),
301
+ (other / self.S2),
302
+ (other / self.S2L),
303
+ self.j1,
304
+ self.j2,
305
+ backend=self.backend,
306
+ )
307
+
308
+ def __sub__(self, other):
309
+ assert (
310
+ isinstance(other, float)
311
+ or isinstance(other, np.float32)
312
+ or isinstance(other, int)
313
+ or isinstance(other, bool)
314
+ or isinstance(other, scat1D)
315
+ )
316
+
248
317
  if isinstance(other, scat1D):
249
- return scat1D(self.domin(self.P00, other.P00), \
250
- self.domin(self.S0, other.S0), \
251
- self.domin(self.S1, other.S1), \
252
- self.domin(self.S2, other.S2), \
253
- self.domin(self.S2L, other.S2L), \
254
- self.j1,self.j2,backend=self.backend)
318
+ return scat1D(
319
+ self.domin(self.P00, other.P00),
320
+ self.domin(self.S0, other.S0),
321
+ self.domin(self.S1, other.S1),
322
+ self.domin(self.S2, other.S2),
323
+ self.domin(self.S2L, other.S2L),
324
+ self.j1,
325
+ self.j2,
326
+ backend=self.backend,
327
+ )
255
328
  else:
256
- return scat1D((self.P00- other), \
257
- (self.S0- other), \
258
- (self.S1- other), \
259
- (self.S2- other), \
260
- (self.S2L- other), \
261
- self.j1,self.j2,backend=self.backend)
262
-
263
- def __rsub__(self,other):
264
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
265
- isinstance(other, bool) or isinstance(other, scat1D)
266
-
329
+ return scat1D(
330
+ (self.P00 - other),
331
+ (self.S0 - other),
332
+ (self.S1 - other),
333
+ (self.S2 - other),
334
+ (self.S2L - other),
335
+ self.j1,
336
+ self.j2,
337
+ backend=self.backend,
338
+ )
339
+
340
+ def __rsub__(self, other):
341
+ assert (
342
+ isinstance(other, float)
343
+ or isinstance(other, np.float32)
344
+ or isinstance(other, int)
345
+ or isinstance(other, bool)
346
+ or isinstance(other, scat1D)
347
+ )
348
+
267
349
  if isinstance(other, scat1D):
268
- return scat1D(self.domin(other.P00,self.P00), \
269
- self.domin(other.S0, self.S0), \
270
- self.domin(other.S1, self.S1), \
271
- self.domin(other.S2, self.S2), \
272
- self.domin(other.S2L, self.S2L), \
273
- self.j1,self.j2,backend=self.backend)
350
+ return scat1D(
351
+ self.domin(other.P00, self.P00),
352
+ self.domin(other.S0, self.S0),
353
+ self.domin(other.S1, self.S1),
354
+ self.domin(other.S2, self.S2),
355
+ self.domin(other.S2L, self.S2L),
356
+ self.j1,
357
+ self.j2,
358
+ backend=self.backend,
359
+ )
274
360
  else:
275
- return scat1D((other-self.P00), \
276
- (other-self.S0), \
277
- (other-self.S1), \
278
- (other-self.S2), \
279
- (other-self.S2L), \
280
- self.j1,self.j2,backend=self.backend)
281
-
282
- def __mul__(self,other):
283
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
284
- isinstance(other, bool) or isinstance(other, scat1D)
285
-
361
+ return scat1D(
362
+ (other - self.P00),
363
+ (other - self.S0),
364
+ (other - self.S1),
365
+ (other - self.S2),
366
+ (other - self.S2L),
367
+ self.j1,
368
+ self.j2,
369
+ backend=self.backend,
370
+ )
371
+
372
+ def __mul__(self, other):
373
+ assert (
374
+ isinstance(other, float)
375
+ or isinstance(other, np.float32)
376
+ or isinstance(other, int)
377
+ or isinstance(other, bool)
378
+ or isinstance(other, scat1D)
379
+ )
380
+
286
381
  if isinstance(other, scat1D):
287
- return scat1D(self.domult(self.P00, other.P00), \
288
- self.domult(self.S0, other.S0), \
289
- self.domult(self.S1, other.S1), \
290
- self.domult(self.S2, other.S2), \
291
- self.domult(self.S2L, other.S2L), \
292
- self.j1,self.j2,backend=self.backend)
382
+ return scat1D(
383
+ self.domult(self.P00, other.P00),
384
+ self.domult(self.S0, other.S0),
385
+ self.domult(self.S1, other.S1),
386
+ self.domult(self.S2, other.S2),
387
+ self.domult(self.S2L, other.S2L),
388
+ self.j1,
389
+ self.j2,
390
+ backend=self.backend,
391
+ )
293
392
  else:
294
- return scat1D((self.P00* other), \
295
- (self.S0* other), \
296
- (self.S1* other), \
297
- (self.S2* other), \
298
- (self.S2L* other), \
299
- self.j1,self.j2,backend=self.backend)
393
+ return scat1D(
394
+ (self.P00 * other),
395
+ (self.S0 * other),
396
+ (self.S1 * other),
397
+ (self.S2 * other),
398
+ (self.S2L * other),
399
+ self.j1,
400
+ self.j2,
401
+ backend=self.backend,
402
+ )
403
+
300
404
  def relu(self):
301
- return scat1D(self.backend.bk_relu(self.P00),
302
- self.backend.bk_relu(self.S0),
303
- self.backend.bk_relu(self.S1),
304
- self.backend.bk_relu(self.S2),
305
- self.backend.bk_relu(self.S2L), \
306
- self.j1,self.j2,backend=self.backend)
307
-
308
-
309
- def __rmul__(self,other):
310
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
311
- isinstance(other, bool) or isinstance(other, scat1D)
312
-
405
+ return scat1D(
406
+ self.backend.bk_relu(self.P00),
407
+ self.backend.bk_relu(self.S0),
408
+ self.backend.bk_relu(self.S1),
409
+ self.backend.bk_relu(self.S2),
410
+ self.backend.bk_relu(self.S2L),
411
+ self.j1,
412
+ self.j2,
413
+ backend=self.backend,
414
+ )
415
+
416
+ def __rmul__(self, other):
417
+ assert (
418
+ isinstance(other, float)
419
+ or isinstance(other, np.float32)
420
+ or isinstance(other, int)
421
+ or isinstance(other, bool)
422
+ or isinstance(other, scat1D)
423
+ )
424
+
313
425
  if isinstance(other, scat1D):
314
- return scat1D(self.domult(self.P00, other.P00), \
315
- self.domult(self.S0, other.S0), \
316
- self.domult(self.S1, other.S1), \
317
- self.domult(self.S2, other.S2), \
318
- self.domult(self.S2L, other.S2L), \
319
- self.j1,self.j2,backend=self.backend)
426
+ return scat1D(
427
+ self.domult(self.P00, other.P00),
428
+ self.domult(self.S0, other.S0),
429
+ self.domult(self.S1, other.S1),
430
+ self.domult(self.S2, other.S2),
431
+ self.domult(self.S2L, other.S2L),
432
+ self.j1,
433
+ self.j2,
434
+ backend=self.backend,
435
+ )
320
436
  else:
321
- return scat1D((self.P00* other), \
322
- (self.S0* other), \
323
- (self.S1* other), \
324
- (self.S2* other), \
325
- (self.S2L* other), \
326
- self.j1,self.j2,backend=self.backend)
327
-
328
- def l1_abs(self,x):
329
- y=self.get_np(x)
437
+ return scat1D(
438
+ (self.P00 * other),
439
+ (self.S0 * other),
440
+ (self.S1 * other),
441
+ (self.S2 * other),
442
+ (self.S2L * other),
443
+ self.j1,
444
+ self.j2,
445
+ backend=self.backend,
446
+ )
447
+
448
+ def l1_abs(self, x):
449
+ y = self.get_np(x)
330
450
  if self.backend.bk_is_complex(y):
331
- tmp=y.real*y.real+y.imag*y.imag
332
- tmp=np.sign(tmp)*np.sqrt(np.fabs(tmp))
333
- y=tmp
334
-
335
- return(y)
336
-
337
- def plot(self,name=None,hold=True,color='blue',lw=1,legend=True):
451
+ tmp = y.real * y.real + y.imag * y.imag
452
+ tmp = np.sign(tmp) * np.sqrt(np.fabs(tmp))
453
+ y = tmp
454
+
455
+ return y
456
+
457
+ def plot(self, name=None, hold=True, color="blue", lw=1, legend=True):
338
458
 
339
459
  import matplotlib.pyplot as plt
340
460
 
341
- j1,j2=self.get_j_idx()
342
-
461
+ j1, j2 = self.get_j_idx()
462
+
343
463
  if name is None:
344
- name=''
464
+ name = ""
345
465
 
346
466
  if hold:
347
- plt.figure(figsize=(16,8))
348
-
349
- test=None
467
+ plt.figure(figsize=(16, 8))
468
+
469
+ test = None
350
470
  plt.subplot(2, 2, 1)
351
- tmp=abs(self.get_np(self.S1))
352
- if len(tmp.shape)==4:
471
+ tmp = abs(self.get_np(self.S1))
472
+ if len(tmp.shape) == 4:
353
473
  for i1 in range(tmp.shape[0]):
354
474
  for i2 in range(tmp.shape[1]):
355
475
  if test is None:
356
- test=1
357
- plt.plot(tmp[i1,i2,:],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
476
+ test = 1
477
+ plt.plot(
478
+ tmp[i1, i2, :],
479
+ color=color,
480
+ label=r"%s $S_{1}$" % (name),
481
+ lw=lw,
482
+ )
358
483
  else:
359
- plt.plot(tmp[i1,i2,:],color=color, lw=lw)
484
+ plt.plot(tmp[i1, i2, :], color=color, lw=lw)
360
485
  else:
361
486
  for i1 in range(tmp.shape[0]):
362
487
  if test is None:
363
- test=1
364
- plt.plot(tmp[i1,:],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
488
+ test = 1
489
+ plt.plot(
490
+ tmp[i1, :], color=color, label=r"%s $S_{1}$" % (name), lw=lw
491
+ )
365
492
  else:
366
- plt.plot(tmp[i1,:],color=color, lw=lw)
367
- plt.yscale('log')
368
- plt.ylabel('S1')
369
- plt.xlabel(r'$j_{1}$')
493
+ plt.plot(tmp[i1, :], color=color, lw=lw)
494
+ plt.yscale("log")
495
+ plt.ylabel("S1")
496
+ plt.xlabel(r"$j_{1}$")
370
497
  plt.legend()
371
498
 
372
- test=None
499
+ test = None
373
500
  plt.subplot(2, 2, 2)
374
- tmp=abs(self.get_np(self.P00))
375
- if len(tmp.shape)==4:
501
+ tmp = abs(self.get_np(self.P00))
502
+ if len(tmp.shape) == 4:
376
503
  for i1 in range(tmp.shape[0]):
377
504
  for i2 in range(tmp.shape[0]):
378
505
  if test is None:
379
- test=1
380
- plt.plot(tmp[i1,i2,:],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
506
+ test = 1
507
+ plt.plot(
508
+ tmp[i1, i2, :],
509
+ color=color,
510
+ label=r"%s $P_{00}$" % (name),
511
+ lw=lw,
512
+ )
381
513
  else:
382
- plt.plot(tmp[i1,i2,:],color=color, lw=lw)
514
+ plt.plot(tmp[i1, i2, :], color=color, lw=lw)
383
515
  else:
384
516
  for i1 in range(tmp.shape[0]):
385
517
  if test is None:
386
- test=1
387
- plt.plot(tmp[i1,:],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
518
+ test = 1
519
+ plt.plot(
520
+ tmp[i1, :], color=color, label=r"%s $P_{00}$" % (name), lw=lw
521
+ )
388
522
  else:
389
- plt.plot(tmp[i1,:],color=color, lw=lw)
390
- plt.yscale('log')
391
- plt.ylabel('P00')
392
- plt.xlabel(r'$j_{1}$')
523
+ plt.plot(tmp[i1, :], color=color, lw=lw)
524
+ plt.yscale("log")
525
+ plt.ylabel("P00")
526
+ plt.xlabel(r"$j_{1}$")
393
527
  plt.legend()
394
-
395
- ax1=plt.subplot(2, 2, 3)
528
+
529
+ ax1 = plt.subplot(2, 2, 3)
396
530
  ax2 = ax1.twiny()
397
- n=0
398
- tmp=abs(self.get_np(self.S2))
399
- lname=r'%s $S_{2}$' % (name)
400
- ax1.set_ylabel(r'$S_{2}$ [L1 norm]')
401
- test=None
402
- tabx=[]
403
- tabnx=[]
404
- tab2x=[]
405
- tab2nx=[]
406
- if len(tmp.shape)==5:
531
+ n = 0
532
+ tmp = abs(self.get_np(self.S2))
533
+ lname = r"%s $S_{2}$" % (name)
534
+ ax1.set_ylabel(r"$S_{2}$ [L1 norm]")
535
+ test = None
536
+ tabx = []
537
+ tabnx = []
538
+ tab2x = []
539
+ tab2nx = []
540
+ if len(tmp.shape) == 5:
407
541
  for i0 in range(tmp.shape[0]):
408
542
  for i1 in range(tmp.shape[1]):
409
- for i2 in range(j1.max()+1):
410
- if j2[j1==i2].shape[0]==1:
411
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2],'.', \
412
- color=color, lw=lw)
543
+ for i2 in range(j1.max() + 1):
544
+ if j2[j1 == i2].shape[0] == 1:
545
+ ax1.plot(
546
+ j2[j1 == i2] + n,
547
+ tmp[i0, i1, j1 == i2],
548
+ ".",
549
+ color=color,
550
+ lw=lw,
551
+ )
413
552
  else:
414
553
  if legend and test is None:
415
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
416
- color=color, label=lname, lw=lw)
417
- test=1
418
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
419
- color=color, lw=lw)
420
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
421
- tabx=tabx+[k+n for k in j2[j1==i2]]
422
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
423
- tab2nx=tab2nx+['%d'%(i2)]
424
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
425
- n=n+j2[j1==i2].shape[0]-1
554
+ ax1.plot(
555
+ j2[j1 == i2] + n,
556
+ tmp[i0, i1, j1 == i2],
557
+ color=color,
558
+ label=lname,
559
+ lw=lw,
560
+ )
561
+ test = 1
562
+ ax1.plot(
563
+ j2[j1 == i2] + n,
564
+ tmp[i0, i1, j1 == i2],
565
+ color=color,
566
+ lw=lw,
567
+ )
568
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
569
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
570
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
571
+ tab2nx = tab2nx + ["%d" % (i2)]
572
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
573
+ n = n + j2[j1 == i2].shape[0] - 1
426
574
  else:
427
575
  for i0 in range(tmp.shape[0]):
428
- for i2 in range(j1.max()+1):
429
- if j2[j1==i2].shape[0]==1:
430
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2],'.', \
431
- color=color, lw=lw)
576
+ for i2 in range(j1.max() + 1):
577
+ if j2[j1 == i2].shape[0] == 1:
578
+ ax1.plot(
579
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], ".", color=color, lw=lw
580
+ )
432
581
  else:
433
582
  if legend and test is None:
434
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
435
- color=color, label=lname, lw=lw)
436
- test=1
437
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
438
- color=color, lw=lw)
439
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
440
- tabx=tabx+[k+n for k in j2[j1==i2]]
441
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
442
- tab2nx=tab2nx+['%d'%(i2)]
443
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
444
- n=n+j2[j1==i2].shape[0]-1
445
- plt.yscale('log')
446
- ax1.set_xlim(0,n+2)
583
+ ax1.plot(
584
+ j2[j1 == i2] + n,
585
+ tmp[i0, j1 == i2],
586
+ color=color,
587
+ label=lname,
588
+ lw=lw,
589
+ )
590
+ test = 1
591
+ ax1.plot(
592
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], color=color, lw=lw
593
+ )
594
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
595
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
596
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
597
+ tab2nx = tab2nx + ["%d" % (i2)]
598
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
599
+ n = n + j2[j1 == i2].shape[0] - 1
600
+ plt.yscale("log")
601
+ ax1.set_xlim(0, n + 2)
447
602
  ax1.set_xticks(tabx)
448
- ax1.set_xticklabels(tabnx,fontsize=6)
449
- ax1.set_xlabel(r"$j_{2}$ ",fontsize=6)
450
-
603
+ ax1.set_xticklabels(tabnx, fontsize=6)
604
+ ax1.set_xlabel(r"$j_{2}$ ", fontsize=6)
605
+
451
606
  # Move twinned axis ticks and label from top to bottom
452
607
  ax2.xaxis.set_ticks_position("bottom")
453
608
  ax2.xaxis.set_label_position("bottom")
@@ -455,7 +610,7 @@ class scat1D:
455
610
  # Offset the twin axis below the host
456
611
  ax2.spines["bottom"].set_position(("axes", -0.15))
457
612
 
458
- # Turn on the frame for the twin axis, but then hide all
613
+ # Turn on the frame for the twin axis, but then hide all
459
614
  # but the bottom spine
460
615
  ax2.set_frame_on(True)
461
616
  ax2.patch.set_visible(False)
@@ -463,68 +618,91 @@ class scat1D:
463
618
  for sp in ax2.spines.values():
464
619
  sp.set_visible(False)
465
620
  ax2.spines["bottom"].set_visible(True)
466
- ax2.set_xlim(0,n+2)
621
+ ax2.set_xlim(0, n + 2)
467
622
  ax2.set_xticks(tab2x)
468
- ax2.set_xticklabels(tab2nx,fontsize=6)
469
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
623
+ ax2.set_xticklabels(tab2nx, fontsize=6)
624
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
470
625
  ax1.legend(frameon=0)
471
-
472
- ax1=plt.subplot(2, 2, 4)
626
+
627
+ ax1 = plt.subplot(2, 2, 4)
473
628
  ax2 = ax1.twiny()
474
- n=0
475
- tmp=abs(self.get_np(self.S2L))
476
- lname=r'%s $S2_{2}$' % (name)
477
- ax1.set_ylabel(r'$S_{2}$ [L2 norm]')
478
- test=None
479
- tabx=[]
480
- tabnx=[]
481
- tab2x=[]
482
- tab2nx=[]
483
- if len(tmp.shape)==5:
629
+ n = 0
630
+ tmp = abs(self.get_np(self.S2L))
631
+ lname = r"%s $S2_{2}$" % (name)
632
+ ax1.set_ylabel(r"$S_{2}$ [L2 norm]")
633
+ test = None
634
+ tabx = []
635
+ tabnx = []
636
+ tab2x = []
637
+ tab2nx = []
638
+ if len(tmp.shape) == 5:
484
639
  for i0 in range(tmp.shape[0]):
485
640
  for i1 in range(tmp.shape[1]):
486
- for i2 in range(j1.max()+1):
487
- if j2[j1==i2].shape[0]==1:
488
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2],'.', \
489
- color=color, lw=lw)
641
+ for i2 in range(j1.max() + 1):
642
+ if j2[j1 == i2].shape[0] == 1:
643
+ ax1.plot(
644
+ j2[j1 == i2] + n,
645
+ tmp[i0, i1, j1 == i2],
646
+ ".",
647
+ color=color,
648
+ lw=lw,
649
+ )
490
650
  else:
491
651
  if legend and test is None:
492
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
493
- color=color, label=lname, lw=lw)
494
- test=1
495
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
496
- color=color, lw=lw)
497
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
498
- tabx=tabx+[k+n for k in j2[j1==i2]]
499
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
500
- tab2nx=tab2nx+['%d'%(i2)]
501
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
502
- n=n+j2[j1==i2].shape[0]-1
652
+ ax1.plot(
653
+ j2[j1 == i2] + n,
654
+ tmp[i0, i1, j1 == i2],
655
+ color=color,
656
+ label=lname,
657
+ lw=lw,
658
+ )
659
+ test = 1
660
+ ax1.plot(
661
+ j2[j1 == i2] + n,
662
+ tmp[i0, i1, j1 == i2],
663
+ color=color,
664
+ lw=lw,
665
+ )
666
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
667
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
668
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
669
+ tab2nx = tab2nx + ["%d" % (i2)]
670
+ ax1.axvline(
671
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
672
+ )
673
+ n = n + j2[j1 == i2].shape[0] - 1
503
674
  else:
504
675
  for i0 in range(tmp.shape[0]):
505
- for i2 in range(j1.max()+1):
506
- if j2[j1==i2].shape[0]==1:
507
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2],'.', \
508
- color=color, lw=lw)
676
+ for i2 in range(j1.max() + 1):
677
+ if j2[j1 == i2].shape[0] == 1:
678
+ ax1.plot(
679
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], ".", color=color, lw=lw
680
+ )
509
681
  else:
510
682
  if legend and test is None:
511
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
512
- color=color, label=lname, lw=lw)
513
- test=1
514
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
515
- color=color, lw=lw)
516
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
517
- tabx=tabx+[k+n for k in j2[j1==i2]]
518
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
519
- tab2nx=tab2nx+['%d'%(i2)]
520
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
521
- n=n+j2[j1==i2].shape[0]-1
522
- plt.yscale('log')
523
- ax1.set_xlim(-1,n+3)
683
+ ax1.plot(
684
+ j2[j1 == i2] + n,
685
+ tmp[i0, j1 == i2],
686
+ color=color,
687
+ label=lname,
688
+ lw=lw,
689
+ )
690
+ test = 1
691
+ ax1.plot(
692
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], color=color, lw=lw
693
+ )
694
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
695
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
696
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
697
+ tab2nx = tab2nx + ["%d" % (i2)]
698
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
699
+ n = n + j2[j1 == i2].shape[0] - 1
700
+ plt.yscale("log")
701
+ ax1.set_xlim(-1, n + 3)
524
702
  ax1.set_xticks(tabx)
525
- ax1.set_xticklabels(tabnx,fontsize=6)
526
- ax1.set_xlabel(r"$j_{2}$",fontsize=6)
527
-
703
+ ax1.set_xticklabels(tabnx, fontsize=6)
704
+ ax1.set_xlabel(r"$j_{2}$", fontsize=6)
705
+
528
706
  # Move twinned axis ticks and label from top to bottom
529
707
  ax2.xaxis.set_ticks_position("bottom")
530
708
  ax2.xaxis.set_label_position("bottom")
@@ -532,7 +710,7 @@ class scat1D:
532
710
  # Offset the twin axis below the host
533
711
  ax2.spines["bottom"].set_position(("axes", -0.15))
534
712
 
535
- # Turn on the frame for the twin axis, but then hide all
713
+ # Turn on the frame for the twin axis, but then hide all
536
714
  # but the bottom spine
537
715
  ax2.set_frame_on(True)
538
716
  ax2.patch.set_visible(False)
@@ -540,125 +718,182 @@ class scat1D:
540
718
  for sp in ax2.spines.values():
541
719
  sp.set_visible(False)
542
720
  ax2.spines["bottom"].set_visible(True)
543
- ax2.set_xlim(0,n+3)
721
+ ax2.set_xlim(0, n + 3)
544
722
  ax2.set_xticks(tab2x)
545
- ax2.set_xticklabels(tab2nx,fontsize=6)
546
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
723
+ ax2.set_xticklabels(tab2nx, fontsize=6)
724
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
547
725
  ax1.legend(frameon=0)
548
-
549
- def save(self,filename):
550
- outlist=[self.get_S0().numpy(), \
551
- self.get_S1().numpy(), \
552
- self.get_S2().numpy(), \
553
- self.get_S2L().numpy(), \
554
- self.get_P00().numpy(), \
555
- self.j1, \
556
- self.j2]
557
-
558
- myout=open("%s.pkl"%(filename),"wb")
559
- pickle.dump(outlist,myout)
726
+
727
+ def save(self, filename):
728
+ outlist = [
729
+ self.get_S0().numpy(),
730
+ self.get_S1().numpy(),
731
+ self.get_S2().numpy(),
732
+ self.get_S2L().numpy(),
733
+ self.get_P00().numpy(),
734
+ self.j1,
735
+ self.j2,
736
+ ]
737
+
738
+ myout = open("%s.pkl" % (filename), "wb")
739
+ pickle.dump(outlist, myout)
560
740
  myout.close()
561
741
 
562
-
563
- def read(self,filename):
564
-
565
- outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
742
+ def read(self, filename):
743
+
744
+ outlist = pickle.load(open("%s.pkl" % (filename), "rb"))
745
+
746
+ return scat1D(
747
+ outlist[4],
748
+ outlist[0],
749
+ outlist[1],
750
+ outlist[2],
751
+ outlist[3],
752
+ outlist[5],
753
+ outlist[6],
754
+ backend=bk.foscat_backend("numpy"),
755
+ )
566
756
 
567
- return scat1D(outlist[4],outlist[0],outlist[1],outlist[2],outlist[3],outlist[5],outlist[6],backend=bk.foscat_backend('numpy'))
568
-
569
- def get_np(self,x):
757
+ def get_np(self, x):
570
758
  if isinstance(x, np.ndarray):
571
759
  return x
572
760
  else:
573
761
  return x.numpy()
574
762
 
575
763
  def std(self):
576
- return np.sqrt(((abs(self.get_np(self.S0)).std())**2+ \
577
- (abs(self.get_np(self.S1)).std())**2+ \
578
- (abs(self.get_np(self.S2)).std())**2+ \
579
- (abs(self.get_np(self.S2L)).std())**2+ \
580
- (abs(self.get_np(self.P00)).std())**2)/4)
764
+ return np.sqrt(
765
+ (
766
+ (abs(self.get_np(self.S0)).std()) ** 2
767
+ + (abs(self.get_np(self.S1)).std()) ** 2
768
+ + (abs(self.get_np(self.S2)).std()) ** 2
769
+ + (abs(self.get_np(self.S2L)).std()) ** 2
770
+ + (abs(self.get_np(self.P00)).std()) ** 2
771
+ )
772
+ / 4
773
+ )
581
774
 
582
775
  def mean(self):
583
- return abs(self.get_np(self.S0).mean()+ \
584
- self.get_np(self.S1).mean()+ \
585
- self.get_np(self.S2).mean()+ \
586
- self.get_np(self.S2L).mean()+ \
587
- self.get_np(self.P00).mean())/3
588
-
589
-
776
+ return (
777
+ abs(
778
+ self.get_np(self.S0).mean()
779
+ + self.get_np(self.S1).mean()
780
+ + self.get_np(self.S2).mean()
781
+ + self.get_np(self.S2L).mean()
782
+ + self.get_np(self.P00).mean()
783
+ )
784
+ / 3
785
+ )
590
786
 
591
787
  # ---------------------------------------------−---------
592
- def cleanval(self,x):
593
- x=x.numpy()
594
- x[np.isfinite(x)==False]=np.median(x[np.isfinite(x)])
788
+ def cleanval(self, x):
789
+ x = x.numpy()
790
+ x[~np.isfinite(x)] = np.median(x[np.isfinite(x)])
595
791
  return x
596
792
 
597
793
  def filter_inf(self):
598
- S1 = self.cleanval(self.S1)
599
- S0 = self.cleanval(self.S0)
794
+ S1 = self.cleanval(self.S1)
795
+ S0 = self.cleanval(self.S0)
600
796
  P00 = self.cleanval(self.P00)
601
- S2 = self.cleanval(self.S2)
797
+ S2 = self.cleanval(self.S2)
602
798
  S2L = self.cleanval(self.S2L)
603
799
 
604
- return scat1D(P00,S0,S1,S2,S2L,self.j1,self.j2,backend=self.backend)
800
+ return scat1D(P00, S0, S1, S2, S2L, self.j1, self.j2, backend=self.backend)
605
801
 
606
802
  # ---------------------------------------------−---------
607
- def interp(self,nscale,extend=False,constant=False,threshold=1E30):
608
-
609
- if nscale+2>self.S1.shape[1]:
610
- print('Can not *interp* %d with a statistic described over %d'%(nscale,self.S1.shape[1]))
611
- return scat1D(self.P00,self.S0,self.S1,self.S2,self.S2L,self.j1,self.j2,backend=self.backend)
612
- if isinstance(self.S1,np.ndarray):
613
- s1=self.S1
614
- p0=self.P00
615
- s2=self.S2
616
- s2l=self.S2L
803
+ def interp(self, nscale, extend=False, constant=False, threshold=1e30):
804
+
805
+ if nscale + 2 > self.S1.shape[1]:
806
+ print(
807
+ "Can not *interp* %d with a statistic described over %d"
808
+ % (nscale, self.S1.shape[1])
809
+ )
810
+ return scat1D(
811
+ self.P00,
812
+ self.S0,
813
+ self.S1,
814
+ self.S2,
815
+ self.S2L,
816
+ self.j1,
817
+ self.j2,
818
+ backend=self.backend,
819
+ )
820
+ if isinstance(self.S1, np.ndarray):
821
+ s1 = self.S1
822
+ p0 = self.P00
823
+ s2 = self.S2
824
+ s2l = self.S2L
617
825
  else:
618
- s1=self.S1.numpy()
619
- p0=self.P00.numpy()
620
- s2=self.S2.numpy()
621
- s2l=self.S2L.numpy()
622
-
623
- if isinstance(threshold,scat1D):
624
- if isinstance(threshold.S1,np.ndarray):
625
- s1th=threshold.S1
626
- p0th=threshold.P00
627
- s2th=threshold.S2
628
- s2lth=threshold.S2L
826
+ s1 = self.S1.numpy()
827
+ p0 = self.P00.numpy()
828
+ s2 = self.S2.numpy()
829
+ s2l = self.S2L.numpy()
830
+
831
+ if isinstance(threshold, scat1D):
832
+ if isinstance(threshold.S1, np.ndarray):
833
+ s1th = threshold.S1
834
+ p0th = threshold.P00
835
+ s2th = threshold.S2
836
+ s2lth = threshold.S2L
629
837
  else:
630
- s1th=threshold.S1.numpy()
631
- p0th=threshold.P00.numpy()
632
- s2th=threshold.S2.numpy()
633
- s2lth=threshold.S2L.numpy()
838
+ s1th = threshold.S1.numpy()
839
+ p0th = threshold.P00.numpy()
840
+ s2th = threshold.S2.numpy()
841
+ s2lth = threshold.S2L.numpy()
634
842
  else:
635
- s1th=threshold+0*s1
636
- p0th=threshold+0*p0
637
- s2th=threshold+0*s2
638
- s2lth=threshold+0*s2l
843
+ s1th = threshold + 0 * s1
844
+ p0th = threshold + 0 * p0
845
+ s2th = threshold + 0 * s2
846
+ s2lth = threshold + 0 * s2l
639
847
 
640
848
  for k in range(nscale):
641
849
  if constant:
642
- s1[:,nscale-1-k,:]=s1[:,nscale-k,:]
643
- p0[:,nscale-1-k,:]=p0[:,nscale-k,:]
850
+ s1[:, nscale - 1 - k, :] = s1[:, nscale - k, :]
851
+ p0[:, nscale - 1 - k, :] = p0[:, nscale - k, :]
644
852
  else:
645
- idx=np.where((s1[:,nscale+1-k,:]>0)*(s1[:,nscale+2-k,:]>0)*(s1[:,nscale-k,:]<s1th[:,nscale-k,:]))
646
- if len(idx[0])>0:
647
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(s1[idx[0],nscale+1-k,idx[1]])-2*np.log(s1[idx[0],nscale+2-k,idx[1]]))
648
- idx=np.where((s1[:,nscale-k,:]>0)*(s1[:,nscale+1-k,:]>0)*(s1[:,nscale-1-k,:]<s1th[:,nscale-1-k,:]))
649
- if len(idx[0])>0:
650
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(s1[idx[0],nscale-k,idx[1]])-np.log(s1[idx[0],nscale+1-k,idx[1]]))
651
-
652
- idx=np.where((p0[:,nscale+1-k,:]>0)*(p0[:,nscale+2-k,:]>0)*(p0[:,nscale-k,:]<p0th[:,nscale-k,:]))
653
- if len(idx[0])>0:
654
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(p0[idx[0],nscale+1-k,idx[1]])-2*np.log(p0[idx[0],nscale+2-k,idx[1]]))
655
-
656
- idx=np.where((p0[:,nscale-k,:]>0)*(p0[:,nscale+1-k,:]>0)*(p0[:,nscale-1-k,:]<p0th[:,nscale-1-k,:]))
657
- if len(idx[0])>0:
658
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(p0[idx[0],nscale-k,idx[1]])-np.log(p0[idx[0],nscale+1-k,idx[1]]))
659
-
660
-
661
- j1,j2=self.get_j_idx()
853
+ idx = np.where(
854
+ (s1[:, nscale + 1 - k, :] > 0)
855
+ * (s1[:, nscale + 2 - k, :] > 0)
856
+ * (s1[:, nscale - k, :] < s1th[:, nscale - k, :])
857
+ )
858
+ if len(idx[0]) > 0:
859
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
860
+ 3 * np.log(s1[idx[0], nscale + 1 - k, idx[1]])
861
+ - 2 * np.log(s1[idx[0], nscale + 2 - k, idx[1]])
862
+ )
863
+ idx = np.where(
864
+ (s1[:, nscale - k, :] > 0)
865
+ * (s1[:, nscale + 1 - k, :] > 0)
866
+ * (s1[:, nscale - 1 - k, :] < s1th[:, nscale - 1 - k, :])
867
+ )
868
+ if len(idx[0]) > 0:
869
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
870
+ 2 * np.log(s1[idx[0], nscale - k, idx[1]])
871
+ - np.log(s1[idx[0], nscale + 1 - k, idx[1]])
872
+ )
873
+
874
+ idx = np.where(
875
+ (p0[:, nscale + 1 - k, :] > 0)
876
+ * (p0[:, nscale + 2 - k, :] > 0)
877
+ * (p0[:, nscale - k, :] < p0th[:, nscale - k, :])
878
+ )
879
+ if len(idx[0]) > 0:
880
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
881
+ 3 * np.log(p0[idx[0], nscale + 1 - k, idx[1]])
882
+ - 2 * np.log(p0[idx[0], nscale + 2 - k, idx[1]])
883
+ )
884
+
885
+ idx = np.where(
886
+ (p0[:, nscale - k, :] > 0)
887
+ * (p0[:, nscale + 1 - k, :] > 0)
888
+ * (p0[:, nscale - 1 - k, :] < p0th[:, nscale - 1 - k, :])
889
+ )
890
+ if len(idx[0]) > 0:
891
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
892
+ 2 * np.log(p0[idx[0], nscale - k, idx[1]])
893
+ - np.log(p0[idx[0], nscale + 1 - k, idx[1]])
894
+ )
895
+
896
+ j1, j2 = self.get_j_idx()
662
897
 
663
898
  for k in range(nscale):
664
899
 
@@ -671,460 +906,632 @@ class scat1D:
671
906
  s2l[:,i0]=np.exp(2*np.log(s2l[:,i1])-np.log(s2l[:,i2]))
672
907
  """
673
908
 
674
- for l in range(nscale-k):
675
- i0=np.where((j1==nscale-1-k-l)*(j2==nscale-1-k))[0]
676
- i1=np.where((j1==nscale-1-k-l)*(j2==nscale -k))[0]
677
- i2=np.where((j1==nscale-1-k-l)*(j2==nscale+1-k))[0]
678
- i3=np.where((j1==nscale-1-k-l)*(j2==nscale+2-k))[0]
679
-
909
+ for l_orient in range(nscale - k):
910
+ i0 = np.where((j1 == nscale - 1 - k - l_orient) * (j2 == nscale - 1 - k))[0]
911
+ i1 = np.where((j1 == nscale - 1 - k - l_orient) * (j2 == nscale - k))[0]
912
+ i2 = np.where((j1 == nscale - 1 - k - l_orient) * (j2 == nscale + 1 - k))[0]
913
+ i3 = np.where((j1 == nscale - 1 - k - l_orient) * (j2 == nscale + 2 - k))[0]
914
+
680
915
  if constant:
681
- s2[:,i0]=s2[:,i1]
682
- s2l[:,i0]=s2l[:,i1]
916
+ s2[:, i0] = s2[:, i1]
917
+ s2l[:, i0] = s2l[:, i1]
683
918
  else:
684
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
685
- if len(idx[0])>0:
686
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
687
-
688
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
689
- if len(idx[0])>0:
690
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
691
-
692
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
693
- if len(idx[0])>0:
694
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
695
-
696
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
697
- if len(idx[0])>0:
698
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
699
-
919
+ idx = np.where(
920
+ (s2[:, i2] > 0) * (s2[:, i3] > 0) * (s2[:, i2] < s2th[:, i2])
921
+ )
922
+ if len(idx[0]) > 0:
923
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
924
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
925
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
926
+ )
927
+
928
+ idx = np.where(
929
+ (s2[:, i1] > 0) * (s2[:, i2] > 0) * (s2[:, i1] < s2th[:, i1])
930
+ )
931
+ if len(idx[0]) > 0:
932
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
933
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
934
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
935
+ )
936
+
937
+ idx = np.where(
938
+ (s2l[:, i2] > 0)
939
+ * (s2l[:, i3] > 0)
940
+ * (s2l[:, i2] < s2lth[:, i2])
941
+ )
942
+ if len(idx[0]) > 0:
943
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
944
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
945
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
946
+ )
947
+
948
+ idx = np.where(
949
+ (s2l[:, i1] > 0)
950
+ * (s2l[:, i2] > 0)
951
+ * (s2l[:, i1] < s2lth[:, i1])
952
+ )
953
+ if len(idx[0]) > 0:
954
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
955
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
956
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
957
+ )
958
+
700
959
  if extend:
701
960
  for k in range(nscale):
702
- for l in range(1,nscale):
703
- i0=np.where((j1==2*nscale-1-k)*(j2==2*nscale-1-k-l))[0]
704
- i1=np.where((j1==2*nscale-1-k)*(j2==2*nscale -k-l))[0]
705
- i2=np.where((j1==2*nscale-1-k)*(j2==2*nscale+1-k-l))[0]
706
- i3=np.where((j1==2*nscale-1-k)*(j2==2*nscale+2-k-l))[0]
961
+ for l_orient in range(1, nscale):
962
+ i0 = np.where(
963
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale - 1 - k - l_orient)
964
+ )[0]
965
+ i1 = np.where(
966
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale - k - l_orient)
967
+ )[0]
968
+ i2 = np.where(
969
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale + 1 - k - l_orient)
970
+ )[0]
971
+ i3 = np.where(
972
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale + 2 - k - l_orient)
973
+ )[0]
707
974
  if constant:
708
- s2[:,i0]=s2[:,i1]
709
- s2l[:,i0]=s2l[:,i1]
975
+ s2[:, i0] = s2[:, i1]
976
+ s2l[:, i0] = s2l[:, i1]
710
977
  else:
711
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
712
- if len(idx[0])>0:
713
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
714
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
715
- if len(idx[0])>0:
716
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
717
-
718
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
719
- if len(idx[0])>0:
720
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
721
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
722
- if len(idx[0])>0:
723
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
724
-
725
- s1[np.isnan(s1)]=0.0
726
- p0[np.isnan(p0)]=0.0
727
- s2[np.isnan(s2)]=0.0
728
- s2l[np.isnan(s2l)]=0.0
729
-
730
- return scat1D(self.backend.constant(p0),self.S0,
731
- self.backend.constant(s1),
732
- self.backend.constant(s2),
733
- self.backend.constant(s2l),self.j1,self.j2,backend=self.backend)
978
+ idx = np.where(
979
+ (s2[:, i2] > 0)
980
+ * (s2[:, i3] > 0)
981
+ * (s2[:, i2] < s2th[:, i2])
982
+ )
983
+ if len(idx[0]) > 0:
984
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
985
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
986
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
987
+ )
988
+ idx = np.where(
989
+ (s2[:, i1] > 0)
990
+ * (s2[:, i2] > 0)
991
+ * (s2[:, i1] < s2th[:, i1])
992
+ )
993
+ if len(idx[0]) > 0:
994
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
995
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
996
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
997
+ )
998
+
999
+ idx = np.where(
1000
+ (s2l[:, i2] > 0)
1001
+ * (s2l[:, i3] > 0)
1002
+ * (s2l[:, i2] < s2lth[:, i2])
1003
+ )
1004
+ if len(idx[0]) > 0:
1005
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1006
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1007
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1008
+ )
1009
+ idx = np.where(
1010
+ (s2l[:, i1] > 0)
1011
+ * (s2l[:, i2] > 0)
1012
+ * (s2l[:, i1] < s2lth[:, i1])
1013
+ )
1014
+ if len(idx[0]) > 0:
1015
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1016
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1017
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1018
+ )
1019
+
1020
+ s1[np.isnan(s1)] = 0.0
1021
+ p0[np.isnan(p0)] = 0.0
1022
+ s2[np.isnan(s2)] = 0.0
1023
+ s2l[np.isnan(s2l)] = 0.0
1024
+
1025
+ return scat1D(
1026
+ self.backend.constant(p0),
1027
+ self.S0,
1028
+ self.backend.constant(s1),
1029
+ self.backend.constant(s2),
1030
+ self.backend.constant(s2l),
1031
+ self.j1,
1032
+ self.j2,
1033
+ backend=self.backend,
1034
+ )
734
1035
 
735
1036
  # ---------------------------------------------−---------
736
- def model(self,i__y,add=0,dx=3,dell=2,weigth=None,inverse=False):
1037
+ def model(self, i__y, add=0, dx=3, dell=2, weigth=None, inverse=False):
737
1038
 
738
- if i__y.shape[0]<dx+1:
739
- l__dx=i__y.shape[0]-1
1039
+ if i__y.shape[0] < dx + 1:
1040
+ l__dx = i__y.shape[0] - 1
740
1041
  else:
741
- l__dx=dx
1042
+ l__dx = dx
742
1043
 
743
- if i__y.shape[0]<dell:
744
- l__dell=0
1044
+ if i__y.shape[0] < dell:
1045
+ l__dell = 0
745
1046
  else:
746
- l__dell=dell
1047
+ l__dell = dell
747
1048
 
748
- if l__dx<2:
749
- res=np.zeros([i__y.shape[0]+add])
1049
+ if l__dx < 2:
1050
+ res = np.zeros([i__y.shape[0] + add])
750
1051
  if inverse:
751
- res[:-add]=i__y
1052
+ res[:-add] = i__y
752
1053
  else:
753
- res[add:]=i__y[0:]
1054
+ res[add:] = i__y[0:]
754
1055
  return res
755
1056
 
756
1057
  if weigth is None:
757
- w=2**(np.arange(l__dx))
1058
+ w = 2 ** (np.arange(l__dx))
758
1059
  else:
759
1060
  if not inverse:
760
- w=weigth[0:l__dx]
1061
+ w = weigth[0:l__dx]
761
1062
  else:
762
- w=weigth[-l__dx:]
1063
+ w = weigth[-l__dx:]
763
1064
 
764
- x=np.arange(l__dx)+1
1065
+ x = np.arange(l__dx) + 1
765
1066
  if not inverse:
766
- y=np.log(i__y[1:l__dx+1])
1067
+ y = np.log(i__y[1 : l__dx + 1])
767
1068
  else:
768
- y=np.log(i__y[-(l__dx+1):-1])
1069
+ y = np.log(i__y[-(l__dx + 1) : -1])
769
1070
 
770
- r=np.polyfit(x,y,1,w=w)
1071
+ r = np.polyfit(x, y, 1, w=w)
771
1072
 
772
1073
  if inverse:
773
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-1)+r[1])
774
- res[:-(l__dell+add)]=i__y[:-l__dell]
1074
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - 1) + r[1])
1075
+ res[: -(l__dell + add)] = i__y[:-l__dell]
775
1076
  else:
776
- res=np.exp(r[0]*(np.arange(i__y.shape[0]+add)-add)+r[1])
777
- res[l__dell+add:]=i__y[l__dell:]
1077
+ res = np.exp(r[0] * (np.arange(i__y.shape[0] + add) - add) + r[1])
1078
+ res[l__dell + add :] = i__y[l__dell:]
778
1079
  return res
779
1080
 
780
- def findn(self,n):
781
- d=np.sqrt(1+8*n)
782
- return int((d-1)/2)
1081
+ def findn(self, n):
1082
+ d = np.sqrt(1 + 8 * n)
1083
+ return int((d - 1) / 2)
783
1084
 
784
- def findidx(self,s2):
785
- i1=np.zeros([s2.shape[1]],dtype='int')
786
- i2=np.zeros([s2.shape[1]],dtype='int')
787
- n=0
1085
+ def findidx(self, s2):
1086
+ i1 = np.zeros([s2.shape[1]], dtype="int")
1087
+ i2 = np.zeros([s2.shape[1]], dtype="int")
1088
+ n = 0
788
1089
  for k in range(self.findn(s2.shape[1])):
789
- i1[n:n+k+1]=np.arange(k+1)
790
- i2[n:n+k+1]=k
791
- n=n+k+1
792
- return i1,i2
793
-
794
- def extrapol_s2(self,add,lnorm=1):
795
- if lnorm==1:
796
- s2=self.S2.numpy()
797
- if lnorm==2:
798
- s2=self.S2L.numpy()
799
- i1,i2=self.findidx(s2)
800
-
801
- so2=np.zeros([s2.shape[0],(self.findn(s2.shape[1])+add)*(self.findn(s2.shape[1])+add+1)//2,s2.shape[2],s2.shape[3]])
802
- oi1,oi2=self.findidx(so2)
803
- for l in range(s2.shape[0]):
1090
+ i1[n : n + k + 1] = np.arange(k + 1)
1091
+ i2[n : n + k + 1] = k
1092
+ n = n + k + 1
1093
+ return i1, i2
1094
+
1095
+ def extrapol_s2(self, add, lnorm=1):
1096
+ if lnorm == 1:
1097
+ s2 = self.S2.numpy()
1098
+ if lnorm == 2:
1099
+ s2 = self.S2L.numpy()
1100
+ i1, i2 = self.findidx(s2)
1101
+
1102
+ so2 = np.zeros(
1103
+ [
1104
+ s2.shape[0],
1105
+ (self.findn(s2.shape[1]) + add)
1106
+ * (self.findn(s2.shape[1]) + add + 1)
1107
+ // 2,
1108
+ s2.shape[2],
1109
+ s2.shape[3],
1110
+ ]
1111
+ )
1112
+ oi1, oi2 = self.findidx(so2)
1113
+ for l_orient in range(s2.shape[0]):
804
1114
  for k in range(self.findn(s2.shape[1])):
805
1115
  for i in range(s2.shape[2]):
806
1116
  for j in range(s2.shape[3]):
807
- tmp=self.model(s2[l,i2==k,i,j],dx=4,dell=1,add=add,weigth=np.array([1,2,2,2]))
808
- tmp[np.isnan(tmp)]=0.0
809
- so2[l,oi2==k+add,i,j]=tmp
810
-
811
-
812
- for l in range(s2.shape[0]):
813
- for k in range(add+1,-1,-1):
814
- lidx=np.where(oi2-oi1==k)[0]
815
- lidx2=np.where(oi2-oi1==k+1)[0]
1117
+ tmp = self.model(
1118
+ s2[l_orient, i2 == k, i, j],
1119
+ dx=4,
1120
+ dell=1,
1121
+ add=add,
1122
+ weigth=np.array([1, 2, 2, 2]),
1123
+ )
1124
+ tmp[np.isnan(tmp)] = 0.0
1125
+ so2[l_orient, oi2 == k + add, i, j] = tmp
1126
+
1127
+ for l_orient in range(s2.shape[0]):
1128
+ for k in range(add + 1, -1, -1):
1129
+ lidx = np.where(oi2 - oi1 == k)[0]
1130
+ lidx2 = np.where(oi2 - oi1 == k + 1)[0]
816
1131
  for i in range(s2.shape[2]):
817
1132
  for j in range(s2.shape[3]):
818
- so2[l,lidx[0:add+2-k],i,j]=so2[l,lidx2[0:add+2-k],i,j]
1133
+ so2[l_orient, lidx[0 : add + 2 - k], i, j] = so2[
1134
+ l_orient, lidx2[0 : add + 2 - k], i, j
1135
+ ]
819
1136
 
820
- return(so2)
1137
+ return so2
821
1138
 
822
- def extrapol_s1(self,i_s1,add):
823
- s1=i_s1.numpy()
824
- so1=np.zeros([s1.shape[0],s1.shape[1]+add,s1.shape[2]])
1139
+ def extrapol_s1(self, i_s1, add):
1140
+ s1 = i_s1.numpy()
1141
+ so1 = np.zeros([s1.shape[0], s1.shape[1] + add, s1.shape[2]])
825
1142
  for k in range(s1.shape[0]):
826
1143
  for i in range(s1.shape[2]):
827
- so1[k,:,i]=self.model(s1[k,:,i],dx=4,dell=1,add=add)
828
- so1[k,np.isnan(so1[k,:,i]),i]=0.0
1144
+ so1[k, :, i] = self.model(s1[k, :, i], dx=4, dell=1, add=add)
1145
+ so1[k, np.isnan(so1[k, :, i]), i] = 0.0
829
1146
  return so1
830
1147
 
831
- def extrapol(self,add):
832
- return scat1D(self.extrapol_s1(self.P00,add), \
833
- self.S0, \
834
- self.extrapol_s1(self.S1,add), \
835
- self.extrapol_s2(add,lnorm=1), \
836
- self.extrapol_s2(add,lnorm=2),self.j1,self.j2,backend=self.backend)
837
-
838
-
839
-
840
-
841
-
1148
+ def extrapol(self, add):
1149
+ return scat1D(
1150
+ self.extrapol_s1(self.P00, add),
1151
+ self.S0,
1152
+ self.extrapol_s1(self.S1, add),
1153
+ self.extrapol_s2(add, lnorm=1),
1154
+ self.extrapol_s2(add, lnorm=2),
1155
+ self.j1,
1156
+ self.j2,
1157
+ backend=self.backend,
1158
+ )
1159
+
1160
+
842
1161
  class funct(FOC.FoCUS):
843
1162
 
844
- def fill(self,im,nullval=0):
845
- return self.fill_1d(im,nullval=nullval)
846
-
847
- def ud_grade(self,im,nout,axis=0):
848
- return self.ud_grade_1d(im,nout,axis=axis)
849
-
850
- def up_grade(self,im,nout,axis=0):
851
- return self.up_grade_1d(im,nout,axis=axis)
852
-
853
- def smooth(self,data,axis=0):
854
- return self.smooth_1d(data,axis=axis)
855
-
856
- def convol(self,data,axis=0):
857
- return self.convol_1d(data,axis=axis)
858
-
859
- def eval(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6,Add_R45=False,axis=0):
1163
+ def fill(self, im, nullval=0):
1164
+ return self.fill_1d(im, nullval=nullval)
1165
+
1166
+ def ud_grade(self, im, nout, axis=0):
1167
+ return self.ud_grade_1d(im, nout, axis=axis)
1168
+
1169
+ def up_grade(self, im, nout, axis=0):
1170
+ return self.up_grade_1d(im, nout, axis=axis)
1171
+
1172
+ def smooth(self, data, axis=0):
1173
+ return self.smooth_1d(data, axis=axis)
1174
+
1175
+ def convol(self, data, axis=0):
1176
+ return self.convol_1d(data, axis=axis)
1177
+
1178
+ def eval(
1179
+ self,
1180
+ image1,
1181
+ image2=None,
1182
+ mask=None,
1183
+ Auto=True,
1184
+ s0_off=1e-6,
1185
+ Add_R45=False,
1186
+ axis=0,
1187
+ ):
860
1188
  # Check input consistency
861
1189
  if mask is not None:
862
- if len(image1.shape)==1:
863
- if image1.shape[0]!=mask.shape[1]:
864
- print('The mask should have the same size than the input timeline to eval Scattering')
1190
+ if len(image1.shape) == 1:
1191
+ if image1.shape[0] != mask.shape[1]:
1192
+ print(
1193
+ "The mask should have the same size than the input timeline to eval Scattering"
1194
+ )
865
1195
  return None
866
1196
  else:
867
- if image1.shape[1]!=mask.shape[1]:
868
- print('The mask should have the same size than the input timeline to eval Scattering')
1197
+ if image1.shape[1] != mask.shape[1]:
1198
+ print(
1199
+ "The mask should have the same size than the input timeline to eval Scattering"
1200
+ )
869
1201
  return None
870
-
1202
+
871
1203
  ### AUTO OR CROSS
872
1204
  cross = False
873
1205
  if image2 is not None:
874
1206
  cross = True
875
- all_cross=not Auto
876
- else:
877
- all_cross=False
878
-
1207
+
879
1208
  # determine jmax and nside corresponding to the input map
880
1209
  im_shape = image1.shape
881
1210
 
882
- nside=im_shape[len(image1.shape)-1]
883
- npix=nside
884
-
885
- jmax=int(np.log(nside)/np.log(2)) #-self.OSTEP
1211
+ nside = im_shape[len(image1.shape) - 1]
1212
+ npix = nside
1213
+
1214
+ jmax = int(np.log(nside) / np.log(2)) # -self.OSTEP
886
1215
 
887
1216
  ### LOCAL VARIABLES (IMAGES and MASK)
888
1217
  # Check if image1 is [Npix] or [Nbatch,Npix]
889
- if len(image1.shape)==1:
1218
+ if len(image1.shape) == 1:
890
1219
  # image1 is [Nbatch, Npix]
891
- I1 = self.backend.bk_cast(self.backend.bk_expand_dims(image1,0)) # Local image1 [Nbatch, Npix]
1220
+ I1 = self.backend.bk_cast(
1221
+ self.backend.bk_expand_dims(image1, 0)
1222
+ ) # Local image1 [Nbatch, Npix]
892
1223
  if cross:
893
- I2 = self.backend.bk_cast(self.backend.bk_expand_dims(image2,0)) # Local image2 [Nbatch, Npix]
1224
+ I2 = self.backend.bk_cast(
1225
+ self.backend.bk_expand_dims(image2, 0)
1226
+ ) # Local image2 [Nbatch, Npix]
894
1227
  else:
895
- I1=self.backend.bk_cast(image1)
1228
+ I1 = self.backend.bk_cast(image1)
896
1229
  if cross:
897
- I2=self.backend.bk_cast(image2)
898
-
1230
+ I2 = self.backend.bk_cast(image2)
1231
+
899
1232
  # self.mask is [Nmask, Npix]
900
-
1233
+
901
1234
  if mask is None:
902
1235
  vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
903
1236
  else:
904
1237
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
905
1238
 
906
- if self.KERNELSZ>3:
1239
+ if self.KERNELSZ > 3:
907
1240
  # if the kernel size is bigger than 3 increase the binning before smoothing
908
- l_image1=self.up_grade_1d(I1,nside*2,axis=axis+1)
909
- vmask=self.up_grade_1d(vmask,nside*2,axis=1)
910
-
1241
+ l_image1 = self.up_grade_1d(I1, nside * 2, axis=axis + 1)
1242
+ vmask = self.up_grade_1d(vmask, nside * 2, axis=1)
1243
+
911
1244
  if cross:
912
- l_image2=self.up_grade_1d(I2,nside*2,axis=axis+1)
1245
+ l_image2 = self.up_grade_1d(I2, nside * 2, axis=axis + 1)
913
1246
  else:
914
- l_image1=I1
1247
+ l_image1 = I1
915
1248
  if cross:
916
- l_image2=I2
917
- if len(image1.shape)==1:
918
- s0 = self.backend.bk_reduce_sum(l_image1*vmask,axis=axis+1)
919
- if cross and Auto==False:
920
- s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(l_image2*vmask,axis=axis)])
1249
+ l_image2 = I2
1250
+ if len(image1.shape) == 1:
1251
+ s0 = self.backend.bk_reduce_sum(l_image1 * vmask, axis=axis + 1)
1252
+ if cross and not Auto:
1253
+ s0 = self.backend.bk_concat(
1254
+ [s0, self.backend.bk_reduce_sum(l_image2 * vmask, axis=axis)]
1255
+ )
921
1256
  else:
922
- lmask=self.backend.bk_expand_dims(vmask,0)
923
- lim=self.backend.bk_expand_dims(l_image1,1)
924
- s0 = self.backend.bk_reduce_sum(lim*lmask,axis=axis+2)
925
- if cross and Auto==False:
926
- lim=self.backend.bk_expand_dims(l_image2,1)
927
- s0 = self.backend.bk_concat([s0,self.backend.bk_reduce_sum(lim*lmask,axis=axis+2)])
928
-
929
-
930
-
931
- s1=None
932
- s2=None
933
- s2l=None
934
- p00=None
935
- s2j1=None
936
- s2j2=None
937
-
938
- l2_image=None
939
- l2_image_imag=None
1257
+ lmask = self.backend.bk_expand_dims(vmask, 0)
1258
+ lim = self.backend.bk_expand_dims(l_image1, 1)
1259
+ s0 = self.backend.bk_reduce_sum(lim * lmask, axis=axis + 2)
1260
+ if cross and not Auto:
1261
+ lim = self.backend.bk_expand_dims(l_image2, 1)
1262
+ s0 = self.backend.bk_concat(
1263
+ [s0, self.backend.bk_reduce_sum(lim * lmask, axis=axis + 2)]
1264
+ )
1265
+
1266
+ s1 = None
1267
+ s2 = None
1268
+ s2l = None
1269
+ p00 = None
1270
+ s2j1 = None
1271
+ s2j2 = None
1272
+
1273
+ l2_image = None
940
1274
 
941
1275
  for j1 in range(jmax):
942
- if j1<jmax-self.OSTEP: # stop to add scales
1276
+ if j1 < jmax - self.OSTEP: # stop to add scales
943
1277
  # Convol image along the axis defined by 'axis' using the wavelet defined at
944
1278
  # the foscat initialisation
945
- #c_image_real is [....,Npix_j1,....,Norient]
946
- c_image1=self.convol_1d(l_image1,axis=axis+1)
1279
+ # c_image_real is [....,Npix_j1,....,Norient]
1280
+ c_image1 = self.convol_1d(l_image1, axis=axis + 1)
947
1281
  if cross:
948
- c_image2=self.convol_1d(l_image2,axis=axis+1)
1282
+ c_image2 = self.convol_1d(l_image2, axis=axis + 1)
949
1283
  else:
950
- c_image2=c_image1
1284
+ c_image2 = c_image1
951
1285
 
952
1286
  # Compute (a+ib)*(a+ib)* the last c_image column is the real and imaginary part
953
- conj=c_image1*self.backend.bk_conjugate(c_image2)
954
-
1287
+ conj = c_image1 * self.backend.bk_conjugate(c_image2)
1288
+
955
1289
  if Auto:
956
- conj=self.backend.bk_real(conj)
1290
+ conj = self.backend.bk_real(conj)
957
1291
 
958
1292
  # Compute l_p00 [....,....,Nmask,1]
959
- l_p00 = self.backend.bk_expand_dims(self.backend.bk_reduce_sum(conj*vmask,axis=1),-1)
1293
+ l_p00 = self.backend.bk_expand_dims(
1294
+ self.backend.bk_reduce_sum(conj * vmask, axis=1), -1
1295
+ )
960
1296
 
961
- conj=self.backend.bk_L1(conj)
1297
+ conj = self.backend.bk_L1(conj)
962
1298
 
963
1299
  # Compute l_s1 [....,....,Nmask,1]
964
- l_s1 = self.backend.bk_expand_dims(self.backend.bk_reduce_sum(conj*vmask,axis=1),-1)
965
-
966
- # Concat S1,P00 [....,....,Nmask,j1]
1300
+ l_s1 = self.backend.bk_expand_dims(
1301
+ self.backend.bk_reduce_sum(conj * vmask, axis=1), -1
1302
+ )
1303
+
1304
+ # Concat S1,P00 [....,....,Nmask,j1]
967
1305
  if s1 is None:
968
- s1=l_s1
969
- p00=l_p00
1306
+ s1 = l_s1
1307
+ p00 = l_p00
970
1308
  else:
971
- s1=self.backend.bk_concat([s1,l_s1],axis=-1)
972
- p00=self.backend.bk_concat([p00,l_p00],axis=-1)
1309
+ s1 = self.backend.bk_concat([s1, l_s1], axis=-1)
1310
+ p00 = self.backend.bk_concat([p00, l_p00], axis=-1)
973
1311
 
974
1312
  # Concat l2_image [....,Npix_j1,....,j1]
975
1313
  if l2_image is None:
976
- l2_image=self.backend.bk_expand_dims(conj,axis=1)
1314
+ l2_image = self.backend.bk_expand_dims(conj, axis=1)
977
1315
  else:
978
- l2_image=self.backend.bk_concat([self.backend.bk_expand_dims(conj,axis=1),l2_image],axis=1)
1316
+ l2_image = self.backend.bk_concat(
1317
+ [self.backend.bk_expand_dims(conj, axis=1), l2_image], axis=1
1318
+ )
979
1319
 
980
1320
  # Convol l2_image [....,Npix_j1,....,j1,Norient,Norient]
981
- c2_image=self.convol_1d(self.backend.bk_relu(l2_image),axis=axis+2)
1321
+ c2_image = self.convol_1d(self.backend.bk_relu(l2_image), axis=axis + 2)
982
1322
 
983
- conj2p=c2_image*self.backend.bk_conjugate(c2_image)
984
- conj2pl1=self.backend.bk_L1(conj2p)
1323
+ conj2p = c2_image * self.backend.bk_conjugate(c2_image)
1324
+ conj2pl1 = self.backend.bk_L1(conj2p)
985
1325
 
986
1326
  if Auto:
987
- conj2p=self.backend.bk_real(conj2p)
988
- conj2pl1=self.backend.bk_real(conj2pl1)
1327
+ conj2p = self.backend.bk_real(conj2p)
1328
+ conj2pl1 = self.backend.bk_real(conj2pl1)
989
1329
 
990
- c2_image=self.convol_1d(self.backend.bk_relu(-l2_image),axis=axis+2)
1330
+ c2_image = self.convol_1d(self.backend.bk_relu(-l2_image), axis=axis + 2)
991
1331
 
992
- conj2m=c2_image*self.backend.bk_conjugate(c2_image)
993
- conj2ml1=self.backend.bk_L1(conj2m)
1332
+ conj2m = c2_image * self.backend.bk_conjugate(c2_image)
1333
+ conj2ml1 = self.backend.bk_L1(conj2m)
994
1334
 
995
1335
  if Auto:
996
- conj2m=self.backend.bk_real(conj2m)
997
- conj2ml1=self.backend.bk_real(conj2ml1)
998
-
1336
+ conj2m = self.backend.bk_real(conj2m)
1337
+ conj2ml1 = self.backend.bk_real(conj2ml1)
1338
+
999
1339
  # Convol l_s2 [....,....,Nmask,j1]
1000
- l_s2 = self.backend.bk_reduce_sum((conj2p-conj2m)*self.backend.bk_expand_dims(vmask,1),axis=axis+2)
1001
- l_s2l1 = self.backend.bk_reduce_sum((conj2pl1-conj2ml1)*self.backend.bk_expand_dims(vmask,1),axis=axis+2)
1340
+ l_s2 = self.backend.bk_reduce_sum(
1341
+ (conj2p - conj2m) * self.backend.bk_expand_dims(vmask, 1), axis=axis + 2
1342
+ )
1343
+ l_s2l1 = self.backend.bk_reduce_sum(
1344
+ (conj2pl1 - conj2ml1) * self.backend.bk_expand_dims(vmask, 1),
1345
+ axis=axis + 2,
1346
+ )
1002
1347
 
1003
1348
  # Concat l_s2 [....,....,Nmask,j1*(j1+1)/2,Norient,Norient]
1004
1349
  if s2 is None:
1005
- s2l=l_s2
1006
- s2=l_s2l1
1007
- s2j1=np.arange(l_s2.shape[axis],dtype='int')
1008
- s2j2=j1*np.ones(l_s2.shape[axis],dtype='int')
1350
+ s2l = l_s2
1351
+ s2 = l_s2l1
1352
+ s2j1 = np.arange(l_s2.shape[axis], dtype="int")
1353
+ s2j2 = j1 * np.ones(l_s2.shape[axis], dtype="int")
1009
1354
  else:
1010
- s2=self.backend.bk_concat([s2,l_s2l1],axis=-1)
1011
- s2l=self.backend.bk_concat([s2l,l_s2],axis=-1)
1012
- s2j1=np.concatenate([s2j1,np.arange(l_s2.shape[axis+1],dtype='int')],0)
1013
- s2j2=np.concatenate([s2j2,j1*np.ones(l_s2.shape[axis+1],dtype='int')],0)
1014
-
1015
- if j1!=jmax-1:
1016
- # Rescale vmask [Nmask,Npix_j1//4]
1017
- vmask = self.smooth_1d(vmask,axis=1)
1018
- vmask = self.ud_grade_1d(vmask,vmask.shape[1]//2,axis=1)
1355
+ s2 = self.backend.bk_concat([s2, l_s2l1], axis=-1)
1356
+ s2l = self.backend.bk_concat([s2l, l_s2], axis=-1)
1357
+ s2j1 = np.concatenate(
1358
+ [s2j1, np.arange(l_s2.shape[axis + 1], dtype="int")], 0
1359
+ )
1360
+ s2j2 = np.concatenate(
1361
+ [s2j2, j1 * np.ones(l_s2.shape[axis + 1], dtype="int")], 0
1362
+ )
1363
+
1364
+ if j1 != jmax - 1:
1365
+ # Rescale vmask [Nmask,Npix_j1//4]
1366
+ vmask = self.smooth_1d(vmask, axis=1)
1367
+ vmask = self.ud_grade_1d(vmask, vmask.shape[1] // 2, axis=1)
1019
1368
  if self.mask_thres is not None:
1020
- vmask = self.backend.bk_threshold(vmask,self.mask_thres)
1021
-
1022
- # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1023
- l2_image = self.smooth_1d(l2_image,axis=axis+2)
1024
- l2_image = self.ud_grade_1d(l2_image,l2_image.shape[axis+2]//2,axis=axis+2)
1025
-
1026
- # Rescale l_image [....,Npix_j1//4,....]
1027
- l_image1 = self.smooth_1d(l_image1,axis=axis+1)
1028
- l_image1 = self.ud_grade_1d(l_image1,l_image1.shape[axis+1]//2,axis=axis+1)
1369
+ vmask = self.backend.bk_threshold(vmask, self.mask_thres)
1370
+
1371
+ # Rescale l2_image [....,Npix_j1//4,....,j1,Norient]
1372
+ l2_image = self.smooth_1d(l2_image, axis=axis + 2)
1373
+ l2_image = self.ud_grade_1d(
1374
+ l2_image, l2_image.shape[axis + 2] // 2, axis=axis + 2
1375
+ )
1376
+
1377
+ # Rescale l_image [....,Npix_j1//4,....]
1378
+ l_image1 = self.smooth_1d(l_image1, axis=axis + 1)
1379
+ l_image1 = self.ud_grade_1d(
1380
+ l_image1, l_image1.shape[axis + 1] // 2, axis=axis + 1
1381
+ )
1029
1382
  if cross:
1030
- l_image2 = self.smooth_1d(l_image2,axis=axis+2)
1031
- l_image2 = self.ud_grade_1d(l_image2,l_image2.shape[axis+2]//2,axis=axis+2)
1383
+ l_image2 = self.smooth_1d(l_image2, axis=axis + 2)
1384
+ l_image2 = self.ud_grade_1d(
1385
+ l_image2, l_image2.shape[axis + 2] // 2, axis=axis + 2
1386
+ )
1032
1387
 
1033
- return(scat1D(p00,s0,s1,s2,s2l,s2j1,s2j2,cross=cross,backend=self.backend))
1388
+ return scat1D(
1389
+ p00, s0, s1, s2, s2l, s2j1, s2j2, cross=cross, backend=self.backend
1390
+ )
1034
1391
 
1035
- def square(self,x):
1392
+ def square(self, x):
1036
1393
  # the abs make the complex value usable for reduce_sum or mean
1037
- return scat1D(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1038
- self.backend.bk_square(self.backend.bk_abs(x.S0)),
1039
- self.backend.bk_square(self.backend.bk_abs(x.S1)),
1040
- self.backend.bk_square(self.backend.bk_abs(x.S2)),
1041
- self.backend.bk_square(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1042
-
1043
- def sqrt(self,x):
1394
+ return scat1D(
1395
+ self.backend.bk_square(self.backend.bk_abs(x.P00)),
1396
+ self.backend.bk_square(self.backend.bk_abs(x.S0)),
1397
+ self.backend.bk_square(self.backend.bk_abs(x.S1)),
1398
+ self.backend.bk_square(self.backend.bk_abs(x.S2)),
1399
+ self.backend.bk_square(self.backend.bk_abs(x.S2L)),
1400
+ x.j1,
1401
+ x.j2,
1402
+ backend=self.backend,
1403
+ )
1404
+
1405
+ def sqrt(self, x):
1044
1406
  # the abs make the complex value usable for reduce_sum or mean
1045
- return scat1D(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1046
- self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1047
- self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1048
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1049
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
1050
-
1051
- def reduce_mean(self,x,axis=None):
1407
+ return scat1D(
1408
+ self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1409
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1410
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1411
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1412
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),
1413
+ x.j1,
1414
+ x.j2,
1415
+ backend=self.backend,
1416
+ )
1417
+
1418
+ def reduce_mean(self, x, axis=None):
1052
1419
  if axis is None:
1053
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))+ \
1054
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))+ \
1055
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))+ \
1056
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))+ \
1057
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1058
-
1059
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1060
- np.array(list(x.S0.shape)).prod()+ \
1061
- np.array(list(x.S1.shape)).prod()+ \
1062
- np.array(list(x.S2.shape)).prod()
1063
-
1064
- return tmp/ntmp
1420
+ tmp = (
1421
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))
1422
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))
1423
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))
1424
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))
1425
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1426
+ )
1427
+
1428
+ ntmp = (
1429
+ np.array(list(x.P00.shape)).prod()
1430
+ + np.array(list(x.S0.shape)).prod()
1431
+ + np.array(list(x.S1.shape)).prod()
1432
+ + np.array(list(x.S2.shape)).prod()
1433
+ )
1434
+
1435
+ return tmp / ntmp
1065
1436
  else:
1066
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00,axis=axis))+ \
1067
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0,axis=axis))+ \
1068
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1,axis=axis))+ \
1069
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2,axis=axis))+ \
1070
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L,axis=axis))
1071
-
1072
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1073
- np.array(list(x.S0.shape)).prod()+ \
1074
- np.array(list(x.S1.shape)).prod()+ \
1075
- np.array(list(x.S2.shape)).prod()+ \
1076
- np.array(list(x.S2L.shape)).prod()
1077
-
1078
- return tmp/ntmp
1079
-
1080
- def reduce_sum(self,x,axis=None):
1437
+ tmp = (
1438
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00, axis=axis))
1439
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0, axis=axis))
1440
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1, axis=axis))
1441
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2, axis=axis))
1442
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L, axis=axis))
1443
+ )
1444
+
1445
+ ntmp = (
1446
+ np.array(list(x.P00.shape)).prod()
1447
+ + np.array(list(x.S0.shape)).prod()
1448
+ + np.array(list(x.S1.shape)).prod()
1449
+ + np.array(list(x.S2.shape)).prod()
1450
+ + np.array(list(x.S2L.shape)).prod()
1451
+ )
1452
+
1453
+ return tmp / ntmp
1454
+
1455
+ def reduce_sum(self, x, axis=None):
1081
1456
  if axis is None:
1082
- return self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))+ \
1083
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))+ \
1084
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))+ \
1085
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))+ \
1086
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1457
+ return (
1458
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))
1459
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
1460
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
1461
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
1462
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1463
+ )
1087
1464
  else:
1088
- return scat1D(self.backend.bk_reduce_sum(x.P00,axis=axis),
1089
- self.backend.bk_reduce_sum(x.S0,axis=axis),
1090
- self.backend.bk_reduce_sum(x.S1,axis=axis),
1091
- self.backend.bk_reduce_sum(x.S2,axis=axis),
1092
- self.backend.bk_reduce_sum(x.S2L,axis=axis),x.j1,x.j2,backend=self.backend)
1093
-
1094
- def ldiff(self,sig,x):
1095
- return scat1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1096
- x.domult(sig.S0,x.S0)*x.domult(sig.S0,x.S0),
1097
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1098
- x.domult(sig.S2,x.S2)*x.domult(sig.S2,x.S2),
1099
- x.domult(sig.S2L,x.S2L)*x.domult(sig.S2L,x.S2L),x.j1,x.j2,backend=self.backend)
1100
-
1101
- def log(self,x):
1102
- return scat1D(self.backend.bk_log(x.P00),
1103
- self.backend.bk_log(x.S0),
1104
- self.backend.bk_log(x.S1),
1105
- self.backend.bk_log(x.S2),
1106
- self.backend.bk_log(x.S2L),x.j1,x.j2,backend=self.backend)
1107
- def abs(self,x):
1108
- return scat1D(self.backend.bk_abs(x.P00),
1109
- self.backend.bk_abs(x.S0),
1110
- self.backend.bk_abs(x.S1),
1111
- self.backend.bk_abs(x.S2),
1112
- self.backend.bk_abs(x.S2L),x.j1,x.j2,backend=self.backend)
1113
- def inv(self,x):
1114
- return scat1D(1/(x.P00),1/(x.S0),1/(x.S1),1/(x.S2),1/(x.S2L),x.j1,x.j2,backend=self.backend)
1465
+ return scat1D(
1466
+ self.backend.bk_reduce_sum(x.P00, axis=axis),
1467
+ self.backend.bk_reduce_sum(x.S0, axis=axis),
1468
+ self.backend.bk_reduce_sum(x.S1, axis=axis),
1469
+ self.backend.bk_reduce_sum(x.S2, axis=axis),
1470
+ self.backend.bk_reduce_sum(x.S2L, axis=axis),
1471
+ x.j1,
1472
+ x.j2,
1473
+ backend=self.backend,
1474
+ )
1475
+
1476
+ def ldiff(self, sig, x):
1477
+ return scat1D(
1478
+ x.domult(sig.P00, x.P00) * x.domult(sig.P00, x.P00),
1479
+ x.domult(sig.S0, x.S0) * x.domult(sig.S0, x.S0),
1480
+ x.domult(sig.S1, x.S1) * x.domult(sig.S1, x.S1),
1481
+ x.domult(sig.S2, x.S2) * x.domult(sig.S2, x.S2),
1482
+ x.domult(sig.S2L, x.S2L) * x.domult(sig.S2L, x.S2L),
1483
+ x.j1,
1484
+ x.j2,
1485
+ backend=self.backend,
1486
+ )
1487
+
1488
+ def log(self, x):
1489
+ return scat1D(
1490
+ self.backend.bk_log(x.P00),
1491
+ self.backend.bk_log(x.S0),
1492
+ self.backend.bk_log(x.S1),
1493
+ self.backend.bk_log(x.S2),
1494
+ self.backend.bk_log(x.S2L),
1495
+ x.j1,
1496
+ x.j2,
1497
+ backend=self.backend,
1498
+ )
1499
+
1500
+ def abs(self, x):
1501
+ return scat1D(
1502
+ self.backend.bk_abs(x.P00),
1503
+ self.backend.bk_abs(x.S0),
1504
+ self.backend.bk_abs(x.S1),
1505
+ self.backend.bk_abs(x.S2),
1506
+ self.backend.bk_abs(x.S2L),
1507
+ x.j1,
1508
+ x.j2,
1509
+ backend=self.backend,
1510
+ )
1511
+
1512
+ def inv(self, x):
1513
+ return scat1D(
1514
+ 1 / (x.P00),
1515
+ 1 / (x.S0),
1516
+ 1 / (x.S1),
1517
+ 1 / (x.S2),
1518
+ 1 / (x.S2L),
1519
+ x.j1,
1520
+ x.j2,
1521
+ backend=self.backend,
1522
+ )
1115
1523
 
1116
1524
  def one(self):
1117
- return scat1D(1.0,1.0,1.0,1.0,1.0,[0],[0],backend=self.backend)
1525
+ return scat1D(1.0, 1.0, 1.0, 1.0, 1.0, [0], [0], backend=self.backend)
1118
1526
 
1119
- def eval_comp_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1527
+ def eval_comp_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1120
1528
 
1121
- res=self.eval_fast(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1122
- return res.P00,res.S0,res.S1,res.S2,res.S2L,res.j1,res.j2
1529
+ res = self.eval_fast(image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off)
1530
+ return res.P00, res.S0, res.S1, res.S2, res.S2L, res.j1, res.j2
1123
1531
 
1124
1532
  @tf_function
1125
- def eval_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1126
- p0,s0,s1,s2,s2l,j1,j2=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1127
- return scat1D(p0,s0,s1,s2,s2l,j1,j2,backend=self.backend)
1128
-
1129
-
1130
-
1533
+ def eval_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1534
+ p0, s0, s1, s2, s2l, j1, j2 = self.eval_comp_fast(
1535
+ image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off
1536
+ )
1537
+ return scat1D(p0, s0, s1, s2, s2l, j1, j2, backend=self.backend)