foscat 3.0.9__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat1D.py CHANGED
@@ -1,396 +1,608 @@
1
- import foscat.FoCUS as FOC
2
- import numpy as np
3
- import tensorflow as tf
4
1
  import pickle
2
+ import sys
3
+
4
+ import numpy as np
5
+
5
6
  import foscat.backend as bk
6
-
7
+ import foscat.FoCUS as FOC
8
+
9
+ # Vérifier si TensorFlow est importé et défini
10
+ tf_defined = "tensorflow" in sys.modules
11
+
12
+ if tf_defined:
13
+ import tensorflow as tf
14
+
15
+ tf_function = (
16
+ tf.function
17
+ ) # Facultatif : si vous voulez utiliser TensorFlow dans ce script
18
+ else:
19
+
20
+ def tf_function(func):
21
+ return func
22
+
23
+
7
24
  def read(filename):
8
- thescat=scat1D(1,1,1,1,1,[0],[0])
25
+ thescat = scat1D(1, 1, 1, 1, 1, [0], [0])
9
26
  return thescat.read(filename)
10
-
27
+
28
+
11
29
  class scat1D:
12
- def __init__(self,p00,s0,s1,s2,s2l,j1,j2,cross=False,backend=None):
13
- self.P00=p00
14
- self.S0=s0
15
- self.S1=s1
16
- self.S2=s2
17
- self.S2L=s2l
18
- self.j1=j1
19
- self.j2=j2
20
- self.cross=cross
21
- self.backend=backend
22
-
30
+ def __init__(self, p00, s0, s1, s2, s2l, j1, j2, cross=False, backend=None):
31
+ self.P00 = p00
32
+ self.S0 = s0
33
+ self.S1 = s1
34
+ self.S2 = s2
35
+ self.S2L = s2l
36
+ self.j1 = j1
37
+ self.j2 = j2
38
+ self.cross = cross
39
+ self.backend = backend
40
+
41
+ # ---------------------------------------------−---------
42
+ def build_flat(self, table):
43
+ shape = table.shape
44
+ ndata = 1
45
+ for k in range(1, len(shape)):
46
+ ndata = ndata * shape[k]
47
+ return self.backend.bk_reshape(table, [shape[0], ndata])
48
+
49
+ # ---------------------------------------------−---------
50
+ def flatten(self, S2L=True, P00=True):
51
+ tmp = [self.build_flat(self.S0), self.build_flat(self.S1)]
52
+
53
+ if P00:
54
+ tmp = tmp + [self.build_flat(self.P00)]
55
+
56
+ tmp = tmp + [self.build_flat(self.S2)]
57
+
58
+ if S2L:
59
+ tmp = tmp + [self.build_flat(self.S2L)]
60
+
61
+ if isinstance(self.P00, np.ndarray):
62
+ return np.concatenate(tmp, 1)
63
+ else:
64
+ return self.backend.bk_concat(tmp, 1)
65
+
66
+ # ---------------------------------------------−---------
67
+ def flatten_name(self, S2L=True, P00=True):
68
+
69
+ tmp = ["S0"]
70
+
71
+ tmp = tmp + ["S1_%d" % (k) for k in range(self.S1.shape[-1])]
72
+
73
+ if P00:
74
+ tmp = tmp + ["P00_%d" % (k) for k in range(self.P00.shape[-1])]
75
+
76
+ j1, j2 = self.get_j_idx()
77
+
78
+ tmp = tmp + ["S2_%d-%d" % (j1[k], j2[k]) for k in range(self.S2.shape[-1])]
79
+
80
+ if S2L:
81
+ tmp = tmp + ["S2L_%d-%d" % (j1[k], j2[k]) for k in range(self.S2.shape[-1])]
82
+
83
+ return tmp
84
+
23
85
  def get_j_idx(self):
24
- return self.j1,self.j2
25
-
86
+ return self.j1, self.j2
87
+
26
88
  def get_S0(self):
27
- return(self.S0)
89
+ return self.S0
28
90
 
29
91
  def get_S1(self):
30
- return(self.S1)
31
-
92
+ return self.S1
93
+
32
94
  def get_S2(self):
33
- return(self.S2)
95
+ return self.S2
34
96
 
35
97
  def get_S2L(self):
36
- return(self.S2L)
98
+ return self.S2L
37
99
 
38
100
  def get_P00(self):
39
- return(self.P00)
101
+ return self.P00
40
102
 
41
103
  def reset_P00(self):
42
- self.P00=0*self.P00
104
+ self.P00 = 0 * self.P00
43
105
 
44
106
  def constant(self):
45
- return scat1D(self.backend.constant(self.P00 ), \
46
- self.backend.constant(self.S0 ), \
47
- self.backend.constant(self.S1 ), \
48
- self.backend.constant(self.S2 ), \
49
- self.backend.constant(self.S2L ), \
50
- self.j1 , \
51
- self.j2 ,backend=self.backend)
52
-
53
- def domult(self,x,y):
54
- if x.dtype==y.dtype:
55
- return x*y
107
+ return scat1D(
108
+ self.backend.constant(self.P00),
109
+ self.backend.constant(self.S0),
110
+ self.backend.constant(self.S1),
111
+ self.backend.constant(self.S2),
112
+ self.backend.constant(self.S2L),
113
+ self.j1,
114
+ self.j2,
115
+ backend=self.backend,
116
+ )
117
+
118
+ def domult(self, x, y):
119
+ if x.dtype == y.dtype:
120
+ return x * y
121
+
56
122
  if self.backend.bk_is_complex(x):
57
-
58
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
123
+
124
+ return self.backend.bk_complex(
125
+ self.backend.bk_real(x) * y, self.backend.bk_imag(x) * y
126
+ )
59
127
  else:
60
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
61
-
62
- def dodiv(self,x,y):
63
- if x.dtype==y.dtype:
64
- return x/y
128
+ return self.backend.bk_complex(
129
+ self.backend.bk_real(y) * x, self.backend.bk_imag(y) * x
130
+ )
131
+
132
+ def dodiv(self, x, y):
133
+ if x.dtype == y.dtype:
134
+ return x / y
65
135
  if self.backend.bk_is_complex(x):
66
-
67
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
136
+
137
+ return self.backend.bk_complex(
138
+ self.backend.bk_real(x) / y, self.backend.bk_imag(x) / y
139
+ )
68
140
  else:
69
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
70
-
71
- def domin(self,x,y):
72
- if x.dtype==y.dtype:
73
- return x-y
141
+ return self.backend.bk_complex(
142
+ x / self.backend.bk_real(y), x / self.backend.bk_imag(y)
143
+ )
144
+
145
+ def domin(self, x, y):
146
+ if x.dtype == y.dtype:
147
+ return x - y
148
+
74
149
  if self.backend.bk_is_complex(x):
75
-
76
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
150
+
151
+ return self.backend.bk_complex(
152
+ self.backend.bk_real(x) - y, self.backend.bk_imag(x) - y
153
+ )
77
154
  else:
78
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
79
-
80
- def doadd(self,x,y):
81
- if x.dtype==y.dtype:
82
- return x+y
155
+ return self.backend.bk_complex(
156
+ x - self.backend.bk_real(y), x - self.backend.bk_imag(y)
157
+ )
158
+
159
+ def doadd(self, x, y):
160
+ if x.dtype == y.dtype:
161
+ return x + y
162
+
83
163
  if self.backend.bk_is_complex(x):
84
-
85
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
164
+
165
+ return self.backend.bk_complex(
166
+ self.backend.bk_real(x) + y, self.backend.bk_imag(x) + y
167
+ )
86
168
  else:
87
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
88
-
89
- def relu(self):
90
-
91
- return scat1D(self.backend.bk_relu(self.P00), \
92
- self.backend.bk_relu(self.S0), \
93
- self.backend.bk_relu(self.S1), \
94
- self.backend.bk_relu(self.S2), \
95
- self.backend.bk_relu(self.S2L), \
96
- self.j1,self.j2,backend=self.backend)
97
-
98
- def __add__(self,other):
99
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
100
- isinstance(other, bool) or isinstance(other, scat1D)
101
-
169
+ return self.backend.bk_complex(
170
+ x + self.backend.bk_real(y), x + self.backend.bk_imag(y)
171
+ )
172
+
173
+ def __add__(self, other):
174
+ assert (
175
+ isinstance(other, float)
176
+ or isinstance(other, np.float32)
177
+ or isinstance(other, int)
178
+ or isinstance(other, bool)
179
+ or isinstance(other, scat1D)
180
+ )
181
+
102
182
  if isinstance(other, scat1D):
103
- return scat1D(self.doadd(self.P00,other.P00), \
104
- self.doadd(self.S0, other.S0), \
105
- self.doadd(self.S1, other.S1), \
106
- self.doadd(self.S2, other.S2), \
107
- self.doadd(self.S2L, other.S2L), \
108
- self.j1,self.j2,backend=self.backend)
183
+ return scat1D(
184
+ self.doadd(self.P00, other.P00),
185
+ self.doadd(self.S0, other.S0),
186
+ self.doadd(self.S1, other.S1),
187
+ self.doadd(self.S2, other.S2),
188
+ self.doadd(self.S2L, other.S2L),
189
+ self.j1,
190
+ self.j2,
191
+ backend=self.backend,
192
+ )
109
193
  else:
110
- return scat1D((self.P00+ other), \
111
- (self.S0+ other), \
112
- (self.S1+ other), \
113
- (self.S2+ other), \
114
- (self.S2L+ other), \
115
- self.j1,self.j2,backend=self.backend)
116
-
117
- def toreal(self,value):
194
+ return scat1D(
195
+ (self.P00 + other),
196
+ (self.S0 + other),
197
+ (self.S1 + other),
198
+ (self.S2 + other),
199
+ (self.S2L + other),
200
+ self.j1,
201
+ self.j2,
202
+ backend=self.backend,
203
+ )
204
+
205
+ def toreal(self, value):
118
206
  if value is None:
119
207
  return None
120
-
208
+
121
209
  return self.backend.bk_real(value)
122
210
 
123
- def addcomplex(self,value,amp):
211
+ def addcomplex(self, value, amp):
124
212
  if value is None:
125
213
  return None
126
-
127
- return self.backend.bk_complex(value,amp*value)
128
-
129
- def add_complex(self,amp):
130
- return scat1D(self.addcomplex(self.P00,amp), \
131
- self.addcomplex(self.S0,amp), \
132
- self.addcomplex(self.S1,amp), \
133
- self.addcomplex(self.S2,amp), \
134
- self.addcomplex(self.S2L,amp), \
135
- self.j1,self.j2,backend=self.backend)
136
-
214
+
215
+ return self.backend.bk_complex(value, amp * value)
216
+
217
+ def add_complex(self, amp):
218
+ return scat1D(
219
+ self.addcomplex(self.P00, amp),
220
+ self.addcomplex(self.S0, amp),
221
+ self.addcomplex(self.S1, amp),
222
+ self.addcomplex(self.S2, amp),
223
+ self.addcomplex(self.S2L, amp),
224
+ self.j1,
225
+ self.j2,
226
+ backend=self.backend,
227
+ )
228
+
137
229
  def real(self):
138
- return scat1D(self.toreal(self.P00), \
139
- self.toreal(self.S0), \
140
- self.toreal(self.S1), \
141
- self.toreal(self.S2), \
142
- self.toreal(self.S2L), \
143
- self.j1,self.j2,backend=self.backend)
144
-
145
- def __radd__(self,other):
230
+ return scat1D(
231
+ self.toreal(self.P00),
232
+ self.toreal(self.S0),
233
+ self.toreal(self.S1),
234
+ self.toreal(self.S2),
235
+ self.toreal(self.S2L),
236
+ self.j1,
237
+ self.j2,
238
+ backend=self.backend,
239
+ )
240
+
241
+ def __radd__(self, other):
146
242
  return self.__add__(other)
147
243
 
148
- def __truediv__(self,other):
149
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
150
- isinstance(other, bool) or isinstance(other, scat1D)
151
-
244
+ def __truediv__(self, other):
245
+ assert (
246
+ isinstance(other, float)
247
+ or isinstance(other, np.float32)
248
+ or isinstance(other, int)
249
+ or isinstance(other, bool)
250
+ or isinstance(other, scat1D)
251
+ )
252
+
152
253
  if isinstance(other, scat1D):
153
- return scat1D(self.dodiv(self.P00, other.P00), \
154
- self.dodiv(self.S0, other.S0), \
155
- self.dodiv(self.S1, other.S1), \
156
- self.dodiv(self.S2, other.S2), \
157
- self.dodiv(self.S2L, other.S2L), \
158
- self.j1,self.j2,backend=self.backend)
254
+ return scat1D(
255
+ self.dodiv(self.P00, other.P00),
256
+ self.dodiv(self.S0, other.S0),
257
+ self.dodiv(self.S1, other.S1),
258
+ self.dodiv(self.S2, other.S2),
259
+ self.dodiv(self.S2L, other.S2L),
260
+ self.j1,
261
+ self.j2,
262
+ backend=self.backend,
263
+ )
159
264
  else:
160
- return scat1D((self.P00/ other), \
161
- (self.S0/ other), \
162
- (self.S1/ other), \
163
- (self.S2/ other), \
164
- (self.S2L/ other), \
165
- self.j1,self.j2,backend=self.backend)
166
-
167
-
168
- def __rtruediv__(self,other):
169
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
170
- isinstance(other, bool) or isinstance(other, scat1D)
171
-
265
+ return scat1D(
266
+ (self.P00 / other),
267
+ (self.S0 / other),
268
+ (self.S1 / other),
269
+ (self.S2 / other),
270
+ (self.S2L / other),
271
+ self.j1,
272
+ self.j2,
273
+ backend=self.backend,
274
+ )
275
+
276
+ def __rtruediv__(self, other):
277
+ assert (
278
+ isinstance(other, float)
279
+ or isinstance(other, np.float32)
280
+ or isinstance(other, int)
281
+ or isinstance(other, bool)
282
+ or isinstance(other, scat1D)
283
+ )
284
+
172
285
  if isinstance(other, scat1D):
173
- return scat1D(self.dodiv(other.P00, self.P00), \
174
- self.dodiv(other.S0 , self.S0), \
175
- self.dodiv(other.S1 , self.S1), \
176
- self.dodiv(other.S2 , self.S2), \
177
- self.dodiv(other.S2L , self.S2L), \
178
- self.j1,self.j2,backend=self.backend)
286
+ return scat1D(
287
+ self.dodiv(other.P00, self.P00),
288
+ self.dodiv(other.S0, self.S0),
289
+ self.dodiv(other.S1, self.S1),
290
+ self.dodiv(other.S2, self.S2),
291
+ self.dodiv(other.S2L, self.S2L),
292
+ self.j1,
293
+ self.j2,
294
+ backend=self.backend,
295
+ )
179
296
  else:
180
- return scat1D((other/ self.P00), \
181
- (other / self.S0), \
182
- (other / self.S1), \
183
- (other / self.S2), \
184
- (other / self.S2L), \
185
- self.j1,self.j2,backend=self.backend)
186
-
187
- def __sub__(self,other):
188
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
189
- isinstance(other, bool) or isinstance(other, scat1D)
190
-
297
+ return scat1D(
298
+ (other / self.P00),
299
+ (other / self.S0),
300
+ (other / self.S1),
301
+ (other / self.S2),
302
+ (other / self.S2L),
303
+ self.j1,
304
+ self.j2,
305
+ backend=self.backend,
306
+ )
307
+
308
+ def __sub__(self, other):
309
+ assert (
310
+ isinstance(other, float)
311
+ or isinstance(other, np.float32)
312
+ or isinstance(other, int)
313
+ or isinstance(other, bool)
314
+ or isinstance(other, scat1D)
315
+ )
316
+
191
317
  if isinstance(other, scat1D):
192
- return scat1D(self.domin(self.P00, other.P00), \
193
- self.domin(self.S0, other.S0), \
194
- self.domin(self.S1, other.S1), \
195
- self.domin(self.S2, other.S2), \
196
- self.domin(self.S2L, other.S2L), \
197
- self.j1,self.j2,backend=self.backend)
318
+ return scat1D(
319
+ self.domin(self.P00, other.P00),
320
+ self.domin(self.S0, other.S0),
321
+ self.domin(self.S1, other.S1),
322
+ self.domin(self.S2, other.S2),
323
+ self.domin(self.S2L, other.S2L),
324
+ self.j1,
325
+ self.j2,
326
+ backend=self.backend,
327
+ )
198
328
  else:
199
- return scat1D((self.P00- other), \
200
- (self.S0- other), \
201
- (self.S1- other), \
202
- (self.S2- other), \
203
- (self.S2L- other), \
204
- self.j1,self.j2,backend=self.backend)
205
-
206
- def __rsub__(self,other):
207
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
208
- isinstance(other, bool) or isinstance(other, scat1D)
209
-
329
+ return scat1D(
330
+ (self.P00 - other),
331
+ (self.S0 - other),
332
+ (self.S1 - other),
333
+ (self.S2 - other),
334
+ (self.S2L - other),
335
+ self.j1,
336
+ self.j2,
337
+ backend=self.backend,
338
+ )
339
+
340
+ def __rsub__(self, other):
341
+ assert (
342
+ isinstance(other, float)
343
+ or isinstance(other, np.float32)
344
+ or isinstance(other, int)
345
+ or isinstance(other, bool)
346
+ or isinstance(other, scat1D)
347
+ )
348
+
210
349
  if isinstance(other, scat1D):
211
- return scat1D(self.domin(other.P00,self.P00), \
212
- self.domin(other.S0, self.S0), \
213
- self.domin(other.S1, self.S1), \
214
- self.domin(other.S2, self.S2), \
215
- self.domin(other.S2L, self.S2L), \
216
- self.j1,self.j2,backend=self.backend)
350
+ return scat1D(
351
+ self.domin(other.P00, self.P00),
352
+ self.domin(other.S0, self.S0),
353
+ self.domin(other.S1, self.S1),
354
+ self.domin(other.S2, self.S2),
355
+ self.domin(other.S2L, self.S2L),
356
+ self.j1,
357
+ self.j2,
358
+ backend=self.backend,
359
+ )
217
360
  else:
218
- return scat1D((other-self.P00), \
219
- (other-self.S0), \
220
- (other-self.S1), \
221
- (other-self.S2), \
222
- (other-self.S2L), \
223
- self.j1,self.j2,backend=self.backend)
224
-
225
- def __mul__(self,other):
226
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
227
- isinstance(other, bool) or isinstance(other, scat1D)
228
-
361
+ return scat1D(
362
+ (other - self.P00),
363
+ (other - self.S0),
364
+ (other - self.S1),
365
+ (other - self.S2),
366
+ (other - self.S2L),
367
+ self.j1,
368
+ self.j2,
369
+ backend=self.backend,
370
+ )
371
+
372
+ def __mul__(self, other):
373
+ assert (
374
+ isinstance(other, float)
375
+ or isinstance(other, np.float32)
376
+ or isinstance(other, int)
377
+ or isinstance(other, bool)
378
+ or isinstance(other, scat1D)
379
+ )
380
+
229
381
  if isinstance(other, scat1D):
230
- return scat1D(self.domult(self.P00, other.P00), \
231
- self.domult(self.S0, other.S0), \
232
- self.domult(self.S1, other.S1), \
233
- self.domult(self.S2, other.S2), \
234
- self.domult(self.S2L, other.S2L), \
235
- self.j1,self.j2,backend=self.backend)
382
+ return scat1D(
383
+ self.domult(self.P00, other.P00),
384
+ self.domult(self.S0, other.S0),
385
+ self.domult(self.S1, other.S1),
386
+ self.domult(self.S2, other.S2),
387
+ self.domult(self.S2L, other.S2L),
388
+ self.j1,
389
+ self.j2,
390
+ backend=self.backend,
391
+ )
236
392
  else:
237
- return scat1D((self.P00* other), \
238
- (self.S0* other), \
239
- (self.S1* other), \
240
- (self.S2* other), \
241
- (self.S2L* other), \
242
- self.j1,self.j2,backend=self.backend)
393
+ return scat1D(
394
+ (self.P00 * other),
395
+ (self.S0 * other),
396
+ (self.S1 * other),
397
+ (self.S2 * other),
398
+ (self.S2L * other),
399
+ self.j1,
400
+ self.j2,
401
+ backend=self.backend,
402
+ )
403
+
243
404
  def relu(self):
244
- return scat1D(self.backend.bk_relu(self.P00),
245
- self.backend.bk_relu(self.S0),
246
- self.backend.bk_relu(self.S1),
247
- self.backend.bk_relu(self.S2),
248
- self.backend.bk_relu(self.S2L), \
249
- self.j1,self.j2,backend=self.backend)
250
-
251
-
252
- def __rmul__(self,other):
253
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
254
- isinstance(other, bool) or isinstance(other, scat1D)
255
-
405
+ return scat1D(
406
+ self.backend.bk_relu(self.P00),
407
+ self.backend.bk_relu(self.S0),
408
+ self.backend.bk_relu(self.S1),
409
+ self.backend.bk_relu(self.S2),
410
+ self.backend.bk_relu(self.S2L),
411
+ self.j1,
412
+ self.j2,
413
+ backend=self.backend,
414
+ )
415
+
416
+ def __rmul__(self, other):
417
+ assert (
418
+ isinstance(other, float)
419
+ or isinstance(other, np.float32)
420
+ or isinstance(other, int)
421
+ or isinstance(other, bool)
422
+ or isinstance(other, scat1D)
423
+ )
424
+
256
425
  if isinstance(other, scat1D):
257
- return scat1D(self.domult(self.P00, other.P00), \
258
- self.domult(self.S0, other.S0), \
259
- self.domult(self.S1, other.S1), \
260
- self.domult(self.S2, other.S2), \
261
- self.domult(self.S2L, other.S2L), \
262
- self.j1,self.j2,backend=self.backend)
426
+ return scat1D(
427
+ self.domult(self.P00, other.P00),
428
+ self.domult(self.S0, other.S0),
429
+ self.domult(self.S1, other.S1),
430
+ self.domult(self.S2, other.S2),
431
+ self.domult(self.S2L, other.S2L),
432
+ self.j1,
433
+ self.j2,
434
+ backend=self.backend,
435
+ )
263
436
  else:
264
- return scat1D((self.P00* other), \
265
- (self.S0* other), \
266
- (self.S1* other), \
267
- (self.S2* other), \
268
- (self.S2L* other), \
269
- self.j1,self.j2,backend=self.backend)
270
-
271
- def l1_abs(self,x):
272
- y=self.get_np(x)
437
+ return scat1D(
438
+ (self.P00 * other),
439
+ (self.S0 * other),
440
+ (self.S1 * other),
441
+ (self.S2 * other),
442
+ (self.S2L * other),
443
+ self.j1,
444
+ self.j2,
445
+ backend=self.backend,
446
+ )
447
+
448
+ def l1_abs(self, x):
449
+ y = self.get_np(x)
273
450
  if self.backend.bk_is_complex(y):
274
- tmp=y.real*y.real+y.imag*y.imag
275
- tmp=np.sign(tmp)*np.sqrt(np.fabs(tmp))
276
- y=tmp
277
-
278
- return(y)
279
-
280
- def plot(self,name=None,hold=True,color='blue',lw=1,legend=True):
451
+ tmp = y.real * y.real + y.imag * y.imag
452
+ tmp = np.sign(tmp) * np.sqrt(np.fabs(tmp))
453
+ y = tmp
454
+
455
+ return y
456
+
457
+ def plot(self, name=None, hold=True, color="blue", lw=1, legend=True):
281
458
 
282
459
  import matplotlib.pyplot as plt
283
460
 
284
- j1,j2=self.get_j_idx()
285
-
461
+ j1, j2 = self.get_j_idx()
462
+
286
463
  if name is None:
287
- name=''
464
+ name = ""
288
465
 
289
466
  if hold:
290
- plt.figure(figsize=(16,8))
291
-
292
- test=None
467
+ plt.figure(figsize=(16, 8))
468
+
469
+ test = None
293
470
  plt.subplot(2, 2, 1)
294
- tmp=abs(self.get_np(self.S1))
295
- if len(tmp.shape)==4:
471
+ tmp = abs(self.get_np(self.S1))
472
+ if len(tmp.shape) == 4:
296
473
  for i1 in range(tmp.shape[0]):
297
474
  for i2 in range(tmp.shape[1]):
298
475
  if test is None:
299
- test=1
300
- plt.plot(tmp[i1,i2,:],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
476
+ test = 1
477
+ plt.plot(
478
+ tmp[i1, i2, :],
479
+ color=color,
480
+ label=r"%s $S_{1}$" % (name),
481
+ lw=lw,
482
+ )
301
483
  else:
302
- plt.plot(tmp[i1,i2,:],color=color, lw=lw)
484
+ plt.plot(tmp[i1, i2, :], color=color, lw=lw)
303
485
  else:
304
486
  for i1 in range(tmp.shape[0]):
305
487
  if test is None:
306
- test=1
307
- plt.plot(tmp[i1,:],color=color, label=r'%s $S_{1}$' % (name), lw=lw)
488
+ test = 1
489
+ plt.plot(
490
+ tmp[i1, :], color=color, label=r"%s $S_{1}$" % (name), lw=lw
491
+ )
308
492
  else:
309
- plt.plot(tmp[i1,:],color=color, lw=lw)
310
- plt.yscale('log')
311
- plt.ylabel('S1')
312
- plt.xlabel(r'$j_{1}$')
493
+ plt.plot(tmp[i1, :], color=color, lw=lw)
494
+ plt.yscale("log")
495
+ plt.ylabel("S1")
496
+ plt.xlabel(r"$j_{1}$")
313
497
  plt.legend()
314
498
 
315
- test=None
499
+ test = None
316
500
  plt.subplot(2, 2, 2)
317
- tmp=abs(self.get_np(self.P00))
318
- if len(tmp.shape)==4:
501
+ tmp = abs(self.get_np(self.P00))
502
+ if len(tmp.shape) == 4:
319
503
  for i1 in range(tmp.shape[0]):
320
504
  for i2 in range(tmp.shape[0]):
321
505
  if test is None:
322
- test=1
323
- plt.plot(tmp[i1,i2,:],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
506
+ test = 1
507
+ plt.plot(
508
+ tmp[i1, i2, :],
509
+ color=color,
510
+ label=r"%s $P_{00}$" % (name),
511
+ lw=lw,
512
+ )
324
513
  else:
325
- plt.plot(tmp[i1,i2,:],color=color, lw=lw)
514
+ plt.plot(tmp[i1, i2, :], color=color, lw=lw)
326
515
  else:
327
516
  for i1 in range(tmp.shape[0]):
328
517
  if test is None:
329
- test=1
330
- plt.plot(tmp[i1,:],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
518
+ test = 1
519
+ plt.plot(
520
+ tmp[i1, :], color=color, label=r"%s $P_{00}$" % (name), lw=lw
521
+ )
331
522
  else:
332
- plt.plot(tmp[i1,:],color=color, lw=lw)
333
- plt.yscale('log')
334
- plt.ylabel('P00')
335
- plt.xlabel(r'$j_{1}$')
523
+ plt.plot(tmp[i1, :], color=color, lw=lw)
524
+ plt.yscale("log")
525
+ plt.ylabel("P00")
526
+ plt.xlabel(r"$j_{1}$")
336
527
  plt.legend()
337
-
338
- ax1=plt.subplot(2, 2, 3)
528
+
529
+ ax1 = plt.subplot(2, 2, 3)
339
530
  ax2 = ax1.twiny()
340
- n=0
341
- tmp=abs(self.get_np(self.S2))
342
- lname=r'%s $S_{2}$' % (name)
343
- ax1.set_ylabel(r'$S_{2}$ [L1 norm]')
344
- test=None
345
- tabx=[]
346
- tabnx=[]
347
- tab2x=[]
348
- tab2nx=[]
349
- if len(tmp.shape)==5:
531
+ n = 0
532
+ tmp = abs(self.get_np(self.S2))
533
+ lname = r"%s $S_{2}$" % (name)
534
+ ax1.set_ylabel(r"$S_{2}$ [L1 norm]")
535
+ test = None
536
+ tabx = []
537
+ tabnx = []
538
+ tab2x = []
539
+ tab2nx = []
540
+ if len(tmp.shape) == 5:
350
541
  for i0 in range(tmp.shape[0]):
351
542
  for i1 in range(tmp.shape[1]):
352
- for i2 in range(j1.max()+1):
353
- if j2[j1==i2].shape[0]==1:
354
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2],'.', \
355
- color=color, lw=lw)
543
+ for i2 in range(j1.max() + 1):
544
+ if j2[j1 == i2].shape[0] == 1:
545
+ ax1.plot(
546
+ j2[j1 == i2] + n,
547
+ tmp[i0, i1, j1 == i2],
548
+ ".",
549
+ color=color,
550
+ lw=lw,
551
+ )
356
552
  else:
357
553
  if legend and test is None:
358
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
359
- color=color, label=lname, lw=lw)
360
- test=1
361
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
362
- color=color, lw=lw)
363
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
364
- tabx=tabx+[k+n for k in j2[j1==i2]]
365
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
366
- tab2nx=tab2nx+['%d'%(i2)]
367
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
368
- n=n+j2[j1==i2].shape[0]-1
554
+ ax1.plot(
555
+ j2[j1 == i2] + n,
556
+ tmp[i0, i1, j1 == i2],
557
+ color=color,
558
+ label=lname,
559
+ lw=lw,
560
+ )
561
+ test = 1
562
+ ax1.plot(
563
+ j2[j1 == i2] + n,
564
+ tmp[i0, i1, j1 == i2],
565
+ color=color,
566
+ lw=lw,
567
+ )
568
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
569
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
570
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
571
+ tab2nx = tab2nx + ["%d" % (i2)]
572
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
573
+ n = n + j2[j1 == i2].shape[0] - 1
369
574
  else:
370
575
  for i0 in range(tmp.shape[0]):
371
- for i2 in range(j1.max()+1):
372
- if j2[j1==i2].shape[0]==1:
373
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2],'.', \
374
- color=color, lw=lw)
576
+ for i2 in range(j1.max() + 1):
577
+ if j2[j1 == i2].shape[0] == 1:
578
+ ax1.plot(
579
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], ".", color=color, lw=lw
580
+ )
375
581
  else:
376
582
  if legend and test is None:
377
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
378
- color=color, label=lname, lw=lw)
379
- test=1
380
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
381
- color=color, lw=lw)
382
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
383
- tabx=tabx+[k+n for k in j2[j1==i2]]
384
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
385
- tab2nx=tab2nx+['%d'%(i2)]
386
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
387
- n=n+j2[j1==i2].shape[0]-1
388
- plt.yscale('log')
389
- ax1.set_xlim(0,n+2)
583
+ ax1.plot(
584
+ j2[j1 == i2] + n,
585
+ tmp[i0, j1 == i2],
586
+ color=color,
587
+ label=lname,
588
+ lw=lw,
589
+ )
590
+ test = 1
591
+ ax1.plot(
592
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], color=color, lw=lw
593
+ )
594
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
595
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
596
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
597
+ tab2nx = tab2nx + ["%d" % (i2)]
598
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
599
+ n = n + j2[j1 == i2].shape[0] - 1
600
+ plt.yscale("log")
601
+ ax1.set_xlim(0, n + 2)
390
602
  ax1.set_xticks(tabx)
391
- ax1.set_xticklabels(tabnx,fontsize=6)
392
- ax1.set_xlabel(r"$j_{2}$ ",fontsize=6)
393
-
603
+ ax1.set_xticklabels(tabnx, fontsize=6)
604
+ ax1.set_xlabel(r"$j_{2}$ ", fontsize=6)
605
+
394
606
  # Move twinned axis ticks and label from top to bottom
395
607
  ax2.xaxis.set_ticks_position("bottom")
396
608
  ax2.xaxis.set_label_position("bottom")
@@ -398,7 +610,7 @@ class scat1D:
398
610
  # Offset the twin axis below the host
399
611
  ax2.spines["bottom"].set_position(("axes", -0.15))
400
612
 
401
- # Turn on the frame for the twin axis, but then hide all
613
+ # Turn on the frame for the twin axis, but then hide all
402
614
  # but the bottom spine
403
615
  ax2.set_frame_on(True)
404
616
  ax2.patch.set_visible(False)
@@ -406,68 +618,91 @@ class scat1D:
406
618
  for sp in ax2.spines.values():
407
619
  sp.set_visible(False)
408
620
  ax2.spines["bottom"].set_visible(True)
409
- ax2.set_xlim(0,n+2)
621
+ ax2.set_xlim(0, n + 2)
410
622
  ax2.set_xticks(tab2x)
411
- ax2.set_xticklabels(tab2nx,fontsize=6)
412
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
623
+ ax2.set_xticklabels(tab2nx, fontsize=6)
624
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
413
625
  ax1.legend(frameon=0)
414
-
415
- ax1=plt.subplot(2, 2, 4)
626
+
627
+ ax1 = plt.subplot(2, 2, 4)
416
628
  ax2 = ax1.twiny()
417
- n=0
418
- tmp=abs(self.get_np(self.S2L))
419
- lname=r'%s $S2_{2}$' % (name)
420
- ax1.set_ylabel(r'$S_{2}$ [L2 norm]')
421
- test=None
422
- tabx=[]
423
- tabnx=[]
424
- tab2x=[]
425
- tab2nx=[]
426
- if len(tmp.shape)==5:
629
+ n = 0
630
+ tmp = abs(self.get_np(self.S2L))
631
+ lname = r"%s $S2_{2}$" % (name)
632
+ ax1.set_ylabel(r"$S_{2}$ [L2 norm]")
633
+ test = None
634
+ tabx = []
635
+ tabnx = []
636
+ tab2x = []
637
+ tab2nx = []
638
+ if len(tmp.shape) == 5:
427
639
  for i0 in range(tmp.shape[0]):
428
640
  for i1 in range(tmp.shape[1]):
429
- for i2 in range(j1.max()+1):
430
- if j2[j1==i2].shape[0]==1:
431
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2],'.', \
432
- color=color, lw=lw)
641
+ for i2 in range(j1.max() + 1):
642
+ if j2[j1 == i2].shape[0] == 1:
643
+ ax1.plot(
644
+ j2[j1 == i2] + n,
645
+ tmp[i0, i1, j1 == i2],
646
+ ".",
647
+ color=color,
648
+ lw=lw,
649
+ )
433
650
  else:
434
651
  if legend and test is None:
435
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
436
- color=color, label=lname, lw=lw)
437
- test=1
438
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
439
- color=color, lw=lw)
440
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
441
- tabx=tabx+[k+n for k in j2[j1==i2]]
442
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
443
- tab2nx=tab2nx+['%d'%(i2)]
444
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
445
- n=n+j2[j1==i2].shape[0]-1
652
+ ax1.plot(
653
+ j2[j1 == i2] + n,
654
+ tmp[i0, i1, j1 == i2],
655
+ color=color,
656
+ label=lname,
657
+ lw=lw,
658
+ )
659
+ test = 1
660
+ ax1.plot(
661
+ j2[j1 == i2] + n,
662
+ tmp[i0, i1, j1 == i2],
663
+ color=color,
664
+ lw=lw,
665
+ )
666
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
667
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
668
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
669
+ tab2nx = tab2nx + ["%d" % (i2)]
670
+ ax1.axvline(
671
+ (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
672
+ )
673
+ n = n + j2[j1 == i2].shape[0] - 1
446
674
  else:
447
675
  for i0 in range(tmp.shape[0]):
448
- for i2 in range(j1.max()+1):
449
- if j2[j1==i2].shape[0]==1:
450
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2],'.', \
451
- color=color, lw=lw)
676
+ for i2 in range(j1.max() + 1):
677
+ if j2[j1 == i2].shape[0] == 1:
678
+ ax1.plot(
679
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], ".", color=color, lw=lw
680
+ )
452
681
  else:
453
682
  if legend and test is None:
454
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
455
- color=color, label=lname, lw=lw)
456
- test=1
457
- ax1.plot(j2[j1==i2]+n,tmp[i0,j1==i2], \
458
- color=color, lw=lw)
459
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
460
- tabx=tabx+[k+n for k in j2[j1==i2]]
461
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
462
- tab2nx=tab2nx+['%d'%(i2)]
463
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
464
- n=n+j2[j1==i2].shape[0]-1
465
- plt.yscale('log')
466
- ax1.set_xlim(-1,n+3)
683
+ ax1.plot(
684
+ j2[j1 == i2] + n,
685
+ tmp[i0, j1 == i2],
686
+ color=color,
687
+ label=lname,
688
+ lw=lw,
689
+ )
690
+ test = 1
691
+ ax1.plot(
692
+ j2[j1 == i2] + n, tmp[i0, j1 == i2], color=color, lw=lw
693
+ )
694
+ tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
695
+ tabx = tabx + [k + n for k in j2[j1 == i2]]
696
+ tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
697
+ tab2nx = tab2nx + ["%d" % (i2)]
698
+ ax1.axvline((j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray")
699
+ n = n + j2[j1 == i2].shape[0] - 1
700
+ plt.yscale("log")
701
+ ax1.set_xlim(-1, n + 3)
467
702
  ax1.set_xticks(tabx)
468
- ax1.set_xticklabels(tabnx,fontsize=6)
469
- ax1.set_xlabel(r"$j_{2}$",fontsize=6)
470
-
703
+ ax1.set_xticklabels(tabnx, fontsize=6)
704
+ ax1.set_xlabel(r"$j_{2}$", fontsize=6)
705
+
471
706
  # Move twinned axis ticks and label from top to bottom
472
707
  ax2.xaxis.set_ticks_position("bottom")
473
708
  ax2.xaxis.set_label_position("bottom")
@@ -475,7 +710,7 @@ class scat1D:
475
710
  # Offset the twin axis below the host
476
711
  ax2.spines["bottom"].set_position(("axes", -0.15))
477
712
 
478
- # Turn on the frame for the twin axis, but then hide all
713
+ # Turn on the frame for the twin axis, but then hide all
479
714
  # but the bottom spine
480
715
  ax2.set_frame_on(True)
481
716
  ax2.patch.set_visible(False)
@@ -483,127 +718,182 @@ class scat1D:
483
718
  for sp in ax2.spines.values():
484
719
  sp.set_visible(False)
485
720
  ax2.spines["bottom"].set_visible(True)
486
- ax2.set_xlim(0,n+3)
721
+ ax2.set_xlim(0, n + 3)
487
722
  ax2.set_xticks(tab2x)
488
- ax2.set_xticklabels(tab2nx,fontsize=6)
489
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
723
+ ax2.set_xticklabels(tab2nx, fontsize=6)
724
+ ax2.set_xlabel(r"$j_{1}$", fontsize=6)
490
725
  ax1.legend(frameon=0)
491
-
492
def save(self, filename):
    """Pickle the scattering coefficients to ``<filename>.pkl``.

    The stored list is ``[S0, S1, S2, S2L, P00, j1, j2]``; ``read`` unpacks
    exactly this order.

    Parameters
    ----------
    filename : str
        Output path without the ``.pkl`` extension.
    """

    def _to_np(x):
        # Coefficients may be backend tensors (exposing ``.numpy()``) or plain
        # numpy arrays, e.g. objects produced by ``read`` or ``filter_inf``.
        return x if isinstance(x, np.ndarray) else x.numpy()

    outlist = [
        _to_np(self.get_S0()),
        _to_np(self.get_S1()),
        _to_np(self.get_S2()),
        _to_np(self.get_S2L()),
        _to_np(self.get_P00()),
        self.j1,
        self.j2,
    ]

    # The context manager closes the file even if pickling fails.
    with open("%s.pkl" % (filename), "wb") as myout:
        pickle.dump(outlist, myout)
504
741
 
505
-
506
def read(self, filename):
    """Load coefficients written by ``save`` from ``<filename>.pkl``.

    ``self`` is not modified; a new ``scat1D`` using the numpy backend is
    returned.

    Security note: this unpickles the file, which can execute arbitrary code.
    Only read files from trusted sources.

    Parameters
    ----------
    filename : str
        Input path without the ``.pkl`` extension.
    """
    # The context manager closes the file handle (the original leaked it).
    with open("%s.pkl" % (filename), "rb") as infile:
        outlist = pickle.load(infile)

    # Stored order is [S0, S1, S2, S2L, P00, j1, j2]; the constructor takes
    # (p00, s0, s1, s2, s2l, j1, j2).
    return scat1D(
        outlist[4],
        outlist[0],
        outlist[1],
        outlist[2],
        outlist[3],
        outlist[5],
        outlist[6],
        backend=bk.foscat_backend("numpy"),
    )
511
-
512
def get_np(self, x):
    """Return ``x`` as a numpy array.

    A numpy array is returned as-is (same object). Any other value is
    converted through its ``.numpy()`` method.
    """
    if not isinstance(x, np.ndarray):
        return x.numpy()
    return x
def std(self):
    """Return a pooled spread over all coefficient families.

    For each of S0, S1, S2, S2L and P00, the standard deviation of the
    absolute values is squared. The squares are summed, divided by 4, and
    the square root is returned.
    """
    families = (self.S0, self.S1, self.S2, self.S2L, self.P00)
    total = sum((abs(self.get_np(coeffs)).std()) ** 2 for coeffs in families)
    # NOTE(review): five variances are summed but the divisor is 4 -- confirm
    # whether this constant is intentional before relying on the scale.
    return np.sqrt(total / 4)
def mean(self):
    """Return the absolute value of the summed family means, divided by 3.

    The means of S0, S1, S2, S2L and P00 are added together. The absolute
    value of that sum is divided by 3.
    """
    total = 0
    for coeffs in (self.S0, self.S1, self.S2, self.S2L, self.P00):
        total = total + self.get_np(coeffs).mean()
    # NOTE(review): five means are summed but the divisor is 3 -- confirm
    # whether this constant is intentional.
    return abs(total) / 3
533
786
 
534
787
  # ---------------------------------------------−---------
535
def cleanval(self, x):
    """Return a numpy copy of ``x`` with non-finite entries replaced.

    NaN and +/-inf entries are replaced by the median of the finite entries.
    If no entry is finite, they are replaced by 0.0 instead; the median of an
    empty selection would be NaN, which defeats the purpose.

    Parameters
    ----------
    x : numpy.ndarray or tensor exposing ``.numpy()``
        Values to clean. The input is never modified in place.
    """
    # Copy so that neither a caller's ndarray nor a buffer possibly shared
    # with a backend tensor is mutated.
    arr = np.array(x if isinstance(x, np.ndarray) else x.numpy())
    finite = np.isfinite(arr)
    if finite.all():
        return arr
    fill = np.median(arr[finite]) if finite.any() else 0.0
    arr[~finite] = fill
    return arr
def filter_inf(self):
    """Return a new ``scat1D`` with non-finite coefficients replaced.

    Every coefficient family (S0, S1, S2, S2L, P00) is passed through
    ``cleanval``. The scale indices, the ``cross`` flag and the backend are
    carried over, so a cleaned cross statistic is still marked as cross.
    """
    S1 = self.cleanval(self.S1)
    S0 = self.cleanval(self.S0)
    P00 = self.cleanval(self.P00)
    S2 = self.cleanval(self.S2)
    S2L = self.cleanval(self.S2L)

    return scat1D(
        P00,
        S0,
        S1,
        S2,
        S2L,
        self.j1,
        self.j2,
        cross=self.cross,
        backend=self.backend,
    )
548
801
 
549
802
  # ---------------------------------------------−---------
550
- def interp(self,nscale,extend=False,constant=False,threshold=1E30):
551
-
552
- if nscale+2>self.S1.shape[1]:
553
- print('Can not *interp* %d with a statistic described over %d'%(nscale,self.S1.shape[1]))
554
- return scat1D(self.P00,self.S0,self.S1,self.S2,self.S2L,self.j1,self.j2,backend=self.backend)
555
- if isinstance(self.S1,np.ndarray):
556
- s1=self.S1
557
- p0=self.P00
558
- s2=self.S2
559
- s2l=self.S2L
803
+ def interp(self, nscale, extend=False, constant=False, threshold=1e30):
804
+
805
+ if nscale + 2 > self.S1.shape[1]:
806
+ print(
807
+ "Can not *interp* %d with a statistic described over %d"
808
+ % (nscale, self.S1.shape[1])
809
+ )
810
+ return scat1D(
811
+ self.P00,
812
+ self.S0,
813
+ self.S1,
814
+ self.S2,
815
+ self.S2L,
816
+ self.j1,
817
+ self.j2,
818
+ backend=self.backend,
819
+ )
820
+ if isinstance(self.S1, np.ndarray):
821
+ s1 = self.S1
822
+ p0 = self.P00
823
+ s2 = self.S2
824
+ s2l = self.S2L
560
825
  else:
561
- s1=self.S1.numpy()
562
- p0=self.P00.numpy()
563
- s2=self.S2.numpy()
564
- s2l=self.S2L.numpy()
565
-
566
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
567
-
568
- if isinstance(threshold,scat1D):
569
- if isinstance(threshold.S1,np.ndarray):
570
- s1th=threshold.S1
571
- p0th=threshold.P00
572
- s2th=threshold.S2
573
- s2lth=threshold.S2L
826
+ s1 = self.S1.numpy()
827
+ p0 = self.P00.numpy()
828
+ s2 = self.S2.numpy()
829
+ s2l = self.S2L.numpy()
830
+
831
+ if isinstance(threshold, scat1D):
832
+ if isinstance(threshold.S1, np.ndarray):
833
+ s1th = threshold.S1
834
+ p0th = threshold.P00
835
+ s2th = threshold.S2
836
+ s2lth = threshold.S2L
574
837
  else:
575
- s1th=threshold.S1.numpy()
576
- p0th=threshold.P00.numpy()
577
- s2th=threshold.S2.numpy()
578
- s2lth=threshold.S2L.numpy()
838
+ s1th = threshold.S1.numpy()
839
+ p0th = threshold.P00.numpy()
840
+ s2th = threshold.S2.numpy()
841
+ s2lth = threshold.S2L.numpy()
579
842
  else:
580
- s1th=threshold+0*s1
581
- p0th=threshold+0*p0
582
- s2th=threshold+0*s2
583
- s2lth=threshold+0*s2l
843
+ s1th = threshold + 0 * s1
844
+ p0th = threshold + 0 * p0
845
+ s2th = threshold + 0 * s2
846
+ s2lth = threshold + 0 * s2l
584
847
 
585
848
  for k in range(nscale):
586
849
  if constant:
587
- s1[:,nscale-1-k,:]=s1[:,nscale-k,:]
588
- p0[:,nscale-1-k,:]=p0[:,nscale-k,:]
850
+ s1[:, nscale - 1 - k, :] = s1[:, nscale - k, :]
851
+ p0[:, nscale - 1 - k, :] = p0[:, nscale - k, :]
589
852
  else:
590
- idx=np.where((s1[:,nscale+1-k,:]>0)*(s1[:,nscale+2-k,:]>0)*(s1[:,nscale-k,:]<s1th[:,nscale-k,:]))
591
- if len(idx[0])>0:
592
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(s1[idx[0],nscale+1-k,idx[1]])-2*np.log(s1[idx[0],nscale+2-k,idx[1]]))
593
- idx=np.where((s1[:,nscale-k,:]>0)*(s1[:,nscale+1-k,:]>0)*(s1[:,nscale-1-k,:]<s1th[:,nscale-1-k,:]))
594
- if len(idx[0])>0:
595
- s1[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(s1[idx[0],nscale-k,idx[1]])-np.log(s1[idx[0],nscale+1-k,idx[1]]))
596
-
597
- idx=np.where((p0[:,nscale+1-k,:]>0)*(p0[:,nscale+2-k,:]>0)*(p0[:,nscale-k,:]<p0th[:,nscale-k,:]))
598
- if len(idx[0])>0:
599
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(3*np.log(p0[idx[0],nscale+1-k,idx[1]])-2*np.log(p0[idx[0],nscale+2-k,idx[1]]))
600
-
601
- idx=np.where((p0[:,nscale-k,:]>0)*(p0[:,nscale+1-k,:]>0)*(p0[:,nscale-1-k,:]<p0th[:,nscale-1-k,:]))
602
- if len(idx[0])>0:
603
- p0[idx[0],nscale-1-k,idx[1]]=np.exp(2*np.log(p0[idx[0],nscale-k,idx[1]])-np.log(p0[idx[0],nscale+1-k,idx[1]]))
604
-
605
-
606
- j1,j2=self.get_j_idx()
853
+ idx = np.where(
854
+ (s1[:, nscale + 1 - k, :] > 0)
855
+ * (s1[:, nscale + 2 - k, :] > 0)
856
+ * (s1[:, nscale - k, :] < s1th[:, nscale - k, :])
857
+ )
858
+ if len(idx[0]) > 0:
859
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
860
+ 3 * np.log(s1[idx[0], nscale + 1 - k, idx[1]])
861
+ - 2 * np.log(s1[idx[0], nscale + 2 - k, idx[1]])
862
+ )
863
+ idx = np.where(
864
+ (s1[:, nscale - k, :] > 0)
865
+ * (s1[:, nscale + 1 - k, :] > 0)
866
+ * (s1[:, nscale - 1 - k, :] < s1th[:, nscale - 1 - k, :])
867
+ )
868
+ if len(idx[0]) > 0:
869
+ s1[idx[0], nscale - 1 - k, idx[1]] = np.exp(
870
+ 2 * np.log(s1[idx[0], nscale - k, idx[1]])
871
+ - np.log(s1[idx[0], nscale + 1 - k, idx[1]])
872
+ )
873
+
874
+ idx = np.where(
875
+ (p0[:, nscale + 1 - k, :] > 0)
876
+ * (p0[:, nscale + 2 - k, :] > 0)
877
+ * (p0[:, nscale - k, :] < p0th[:, nscale - k, :])
878
+ )
879
+ if len(idx[0]) > 0:
880
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
881
+ 3 * np.log(p0[idx[0], nscale + 1 - k, idx[1]])
882
+ - 2 * np.log(p0[idx[0], nscale + 2 - k, idx[1]])
883
+ )
884
+
885
+ idx = np.where(
886
+ (p0[:, nscale - k, :] > 0)
887
+ * (p0[:, nscale + 1 - k, :] > 0)
888
+ * (p0[:, nscale - 1 - k, :] < p0th[:, nscale - 1 - k, :])
889
+ )
890
+ if len(idx[0]) > 0:
891
+ p0[idx[0], nscale - 1 - k, idx[1]] = np.exp(
892
+ 2 * np.log(p0[idx[0], nscale - k, idx[1]])
893
+ - np.log(p0[idx[0], nscale + 1 - k, idx[1]])
894
+ )
895
+
896
+ j1, j2 = self.get_j_idx()
607
897
 
608
898
  for k in range(nscale):
609
899
 
@@ -616,449 +906,641 @@ class scat1D:
616
906
  s2l[:,i0]=np.exp(2*np.log(s2l[:,i1])-np.log(s2l[:,i2]))
617
907
  """
618
908
 
619
- for l in range(nscale-k):
620
- i0=np.where((j1==nscale-1-k-l)*(j2==nscale-1-k))[0]
621
- i1=np.where((j1==nscale-1-k-l)*(j2==nscale -k))[0]
622
- i2=np.where((j1==nscale-1-k-l)*(j2==nscale+1-k))[0]
623
- i3=np.where((j1==nscale-1-k-l)*(j2==nscale+2-k))[0]
624
-
909
+ for l_orient in range(nscale - k):
910
+ i0 = np.where(
911
+ (j1 == nscale - 1 - k - l_orient) * (j2 == nscale - 1 - k)
912
+ )[0]
913
+ i1 = np.where((j1 == nscale - 1 - k - l_orient) * (j2 == nscale - k))[0]
914
+ i2 = np.where(
915
+ (j1 == nscale - 1 - k - l_orient) * (j2 == nscale + 1 - k)
916
+ )[0]
917
+ i3 = np.where(
918
+ (j1 == nscale - 1 - k - l_orient) * (j2 == nscale + 2 - k)
919
+ )[0]
920
+
625
921
  if constant:
626
- s2[:,i0]=s2[:,i1]
627
- s2l[:,i0]=s2l[:,i1]
922
+ s2[:, i0] = s2[:, i1]
923
+ s2l[:, i0] = s2l[:, i1]
628
924
  else:
629
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
630
- if len(idx[0])>0:
631
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
632
-
633
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
634
- if len(idx[0])>0:
635
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
636
-
637
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
638
- if len(idx[0])>0:
639
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
640
-
641
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
642
- if len(idx[0])>0:
643
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
644
-
925
+ idx = np.where(
926
+ (s2[:, i2] > 0) * (s2[:, i3] > 0) * (s2[:, i2] < s2th[:, i2])
927
+ )
928
+ if len(idx[0]) > 0:
929
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
930
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
931
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
932
+ )
933
+
934
+ idx = np.where(
935
+ (s2[:, i1] > 0) * (s2[:, i2] > 0) * (s2[:, i1] < s2th[:, i1])
936
+ )
937
+ if len(idx[0]) > 0:
938
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
939
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
940
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
941
+ )
942
+
943
+ idx = np.where(
944
+ (s2l[:, i2] > 0)
945
+ * (s2l[:, i3] > 0)
946
+ * (s2l[:, i2] < s2lth[:, i2])
947
+ )
948
+ if len(idx[0]) > 0:
949
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
950
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
951
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
952
+ )
953
+
954
+ idx = np.where(
955
+ (s2l[:, i1] > 0)
956
+ * (s2l[:, i2] > 0)
957
+ * (s2l[:, i1] < s2lth[:, i1])
958
+ )
959
+ if len(idx[0]) > 0:
960
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
961
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
962
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
963
+ )
964
+
645
965
  if extend:
646
966
  for k in range(nscale):
647
- for l in range(1,nscale):
648
- i0=np.where((j1==2*nscale-1-k)*(j2==2*nscale-1-k-l))[0]
649
- i1=np.where((j1==2*nscale-1-k)*(j2==2*nscale -k-l))[0]
650
- i2=np.where((j1==2*nscale-1-k)*(j2==2*nscale+1-k-l))[0]
651
- i3=np.where((j1==2*nscale-1-k)*(j2==2*nscale+2-k-l))[0]
967
+ for l_orient in range(1, nscale):
968
+ i0 = np.where(
969
+ (j1 == 2 * nscale - 1 - k)
970
+ * (j2 == 2 * nscale - 1 - k - l_orient)
971
+ )[0]
972
+ i1 = np.where(
973
+ (j1 == 2 * nscale - 1 - k) * (j2 == 2 * nscale - k - l_orient)
974
+ )[0]
975
+ i2 = np.where(
976
+ (j1 == 2 * nscale - 1 - k)
977
+ * (j2 == 2 * nscale + 1 - k - l_orient)
978
+ )[0]
979
+ i3 = np.where(
980
+ (j1 == 2 * nscale - 1 - k)
981
+ * (j2 == 2 * nscale + 2 - k - l_orient)
982
+ )[0]
652
983
  if constant:
653
- s2[:,i0]=s2[:,i1]
654
- s2l[:,i0]=s2l[:,i1]
984
+ s2[:, i0] = s2[:, i1]
985
+ s2l[:, i0] = s2l[:, i1]
655
986
  else:
656
- idx=np.where((s2[:,i2]>0)*(s2[:,i3]>0)*(s2[:,i2]<s2th[:,i2]))
657
- print(i0,i2)
658
- if len(idx[0])>0:
659
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2[idx[0],i2,idx[1],idx[2]])-2*np.log(s2[idx[0],i3,idx[1],idx[2]]))
660
- idx=np.where((s2[:,i1]>0)*(s2[:,i2]>0)*(s2[:,i1]<s2th[:,i1]))
661
- if len(idx[0])>0:
662
- s2[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2[idx[0],i1,idx[1],idx[2]])-np.log(s2[idx[0],i2,idx[1],idx[2]]))
663
-
664
- idx=np.where((s2l[:,i2]>0)*(s2l[:,i3]>0)*(s2l[:,i2]<s2lth[:,i2]))
665
- if len(idx[0])>0:
666
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(3*np.log(s2l[idx[0],i2,idx[1],idx[2]])-2*np.log(s2l[idx[0],i3,idx[1],idx[2]]))
667
- idx=np.where((s2l[:,i1]>0)*(s2l[:,i2]>0)*(s2l[:,i1]<s2lth[:,i1]))
668
- if len(idx[0])>0:
669
- s2l[idx[0],i0,idx[1],idx[2]]=np.exp(2*np.log(s2l[idx[0],i1,idx[1],idx[2]])-np.log(s2l[idx[0],i2,idx[1],idx[2]]))
670
-
671
- s1[np.isnan(s1)]=0.0
672
- p0[np.isnan(p0)]=0.0
673
- s2[np.isnan(s2)]=0.0
674
- s2l[np.isnan(s2l)]=0.0
675
- print(s1.sum(),p0.sum(),s2.sum(),s2l.sum())
676
-
677
- return scat1D(self.backend.constant(p0),self.S0,
678
- self.backend.constant(s1),
679
- self.backend.constant(s2),
680
- self.backend.constant(s2l),self.j1,self.j2,backend=self.backend)
987
+ idx = np.where(
988
+ (s2[:, i2] > 0)
989
+ * (s2[:, i3] > 0)
990
+ * (s2[:, i2] < s2th[:, i2])
991
+ )
992
+ if len(idx[0]) > 0:
993
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
994
+ 3 * np.log(s2[idx[0], i2, idx[1], idx[2]])
995
+ - 2 * np.log(s2[idx[0], i3, idx[1], idx[2]])
996
+ )
997
+ idx = np.where(
998
+ (s2[:, i1] > 0)
999
+ * (s2[:, i2] > 0)
1000
+ * (s2[:, i1] < s2th[:, i1])
1001
+ )
1002
+ if len(idx[0]) > 0:
1003
+ s2[idx[0], i0, idx[1], idx[2]] = np.exp(
1004
+ 2 * np.log(s2[idx[0], i1, idx[1], idx[2]])
1005
+ - np.log(s2[idx[0], i2, idx[1], idx[2]])
1006
+ )
1007
+
1008
+ idx = np.where(
1009
+ (s2l[:, i2] > 0)
1010
+ * (s2l[:, i3] > 0)
1011
+ * (s2l[:, i2] < s2lth[:, i2])
1012
+ )
1013
+ if len(idx[0]) > 0:
1014
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1015
+ 3 * np.log(s2l[idx[0], i2, idx[1], idx[2]])
1016
+ - 2 * np.log(s2l[idx[0], i3, idx[1], idx[2]])
1017
+ )
1018
+ idx = np.where(
1019
+ (s2l[:, i1] > 0)
1020
+ * (s2l[:, i2] > 0)
1021
+ * (s2l[:, i1] < s2lth[:, i1])
1022
+ )
1023
+ if len(idx[0]) > 0:
1024
+ s2l[idx[0], i0, idx[1], idx[2]] = np.exp(
1025
+ 2 * np.log(s2l[idx[0], i1, idx[1], idx[2]])
1026
+ - np.log(s2l[idx[0], i2, idx[1], idx[2]])
1027
+ )
1028
+
1029
+ s1[np.isnan(s1)] = 0.0
1030
+ p0[np.isnan(p0)] = 0.0
1031
+ s2[np.isnan(s2)] = 0.0
1032
+ s2l[np.isnan(s2l)] = 0.0
1033
+
1034
+ return scat1D(
1035
+ self.backend.constant(p0),
1036
+ self.S0,
1037
+ self.backend.constant(s1),
1038
+ self.backend.constant(s2),
1039
+ self.backend.constant(s2l),
1040
+ self.j1,
1041
+ self.j2,
1042
+ backend=self.backend,
1043
+ )
681
1044
 
682
1045
  # ---------------------------------------------−---------
683
def model(self, i__y, add=0, dx=3, dell=2, weigth=None, inverse=False):
    """Extend a 1-D positive profile by ``add`` samples using a log-linear fit.

    A degree-1 polynomial is fitted to ``log(i__y)`` over ``dx`` samples.
    The fit is then evaluated on a grid of ``len(i__y) + add`` points.

    - Non-inverse mode: the fit uses samples ``1..dx``. The new samples are
      prepended, and the first ``dell`` measured samples are replaced by the
      fit.
    - Inverse mode: the fit uses the ``dx`` samples preceding the last one.
      The new samples are appended, and the last ``dell`` measured samples
      are replaced by the fit.

    With fewer than 2 usable samples, no fit is done and the profile is
    zero-padded on the extended side instead.

    Parameters
    ----------
    i__y : numpy.ndarray
        1-D profile; should be strictly positive where it is fitted (log).
    add : int
        Number of samples to add.
    dx : int
        Number of samples used for the fit (reduced for short inputs).
    dell : int
        Number of measured samples overwritten by the fit (0 if the input
        is shorter than ``dell``).
    weigth : array-like, optional
        Fit weights. Defaults to ``2**k``.
    inverse : bool
        Extend at the end instead of the start.

    Returns
    -------
    numpy.ndarray
        Array of length ``len(i__y) + add``.
    """
    n = i__y.shape[0]
    l__dx = n - 1 if n < dx + 1 else dx
    l__dell = 0 if n < dell else dell

    if l__dx < 2:
        # Not enough samples to fit a line: zero-pad on the extended side.
        res = np.zeros([n + add])
        if inverse:
            # Explicit stop index: ``res[:-add]`` would be empty when add == 0.
            res[:n] = i__y
        else:
            res[add:] = i__y[0:]
        return res

    if weigth is None:
        w = 2 ** (np.arange(l__dx))
    elif not inverse:
        w = weigth[0:l__dx]
    else:
        w = weigth[-l__dx:]

    x = np.arange(l__dx) + 1
    if not inverse:
        y = np.log(i__y[1 : l__dx + 1])
    else:
        y = np.log(i__y[-(l__dx + 1) : -1])

    r = np.polyfit(x, y, 1, w=w)

    if inverse:
        res = np.exp(r[0] * (np.arange(n + add) - 1) + r[1])
        # Keep every measured sample except the last l__dell. Explicit stop
        # indices avoid the empty ``[:-0]`` slice when l__dell (+ add) is 0.
        res[: n - l__dell] = i__y[: n - l__dell]
    else:
        res = np.exp(r[0] * (np.arange(n + add) - add) + r[1])
        res[l__dell + add :] = i__y[l__dell:]
    return res
726
1089
 
727
def findn(self, n):
    """Invert the triangular number ``n = k * (k + 1) / 2`` and return ``k``.

    For non-triangular ``n``, the real-valued root is truncated toward zero.
    """
    root = np.sqrt(8 * n + 1)
    return int((root - 1) / 2)
730
1093
 
731
def findidx(self, s2):
    """Return two index arrays describing the triangular layout of ``s2``.

    The length of axis 1 of ``s2`` is read as a triangular count. Level
    ``k`` occupies ``k + 1`` consecutive entries. Within a level, the first
    array holds ``0..k`` and the second array holds ``k``. Trailing entries
    beyond the last complete level stay 0.
    """
    ncoef = s2.shape[1]
    first = np.zeros([ncoef], dtype="int")
    second = np.zeros([ncoef], dtype="int")
    start = 0
    for level in range(self.findn(ncoef)):
        stop = start + level + 1
        first[start:stop] = np.arange(level + 1)
        second[start:stop] = level
        start = stop
    return first, second
1103
+
1104
def extrapol_s2(self, add, lnorm=1):
    """Extrapolate S2 (``lnorm=1``) or S2L (``lnorm=2``) by ``add`` scales.

    Each level of the triangular layout (see ``findidx``) is extended with
    ``model``, and NaNs produced by the fit are set to 0. A second pass
    then copies values between diagonals of the enlarged layout.

    Parameters
    ----------
    add : int
        Number of scales to add.
    lnorm : int
        1 selects ``self.S2``, 2 selects ``self.S2L``.

    Returns
    -------
    numpy.ndarray
        Array shaped like the input, with axis 1 enlarged to the triangular
        count for ``nscale + add`` levels.

    Raises
    ------
    ValueError
        If ``lnorm`` is neither 1 nor 2 (previously an UnboundLocalError).
    """
    if lnorm == 1:
        src = self.S2
    elif lnorm == 2:
        src = self.S2L
    else:
        raise ValueError("lnorm must be 1 (S2) or 2 (S2L), got %r" % (lnorm,))
    # Accept both backend tensors and numpy arrays (e.g. after ``read``).
    s2 = src if isinstance(src, np.ndarray) else src.numpy()

    _, i2 = self.findidx(s2)
    nscale = self.findn(s2.shape[1])

    so2 = np.zeros(
        [
            s2.shape[0],
            (nscale + add) * (nscale + add + 1) // 2,
            s2.shape[2],
            s2.shape[3],
        ]
    )
    oi1, oi2 = self.findidx(so2)
    fit_weights = np.array([1, 2, 2, 2])

    # Pass 1: extend each level with the log-linear model.
    for l_orient in range(s2.shape[0]):
        for k in range(nscale):
            for i in range(s2.shape[2]):
                for j in range(s2.shape[3]):
                    tmp = self.model(
                        s2[l_orient, i2 == k, i, j],
                        dx=4,
                        dell=1,
                        add=add,
                        weigth=fit_weights,
                    )
                    tmp[np.isnan(tmp)] = 0.0
                    so2[l_orient, oi2 == k + add, i, j] = tmp

    # Pass 2: copy from diagonal k + 1 into diagonal k, from the outermost
    # diagonal inwards (order matters: later copies read earlier results).
    for l_orient in range(s2.shape[0]):
        for k in range(add + 1, -1, -1):
            lidx = np.where(oi2 - oi1 == k)[0]
            lidx2 = np.where(oi2 - oi1 == k + 1)[0]
            for i in range(s2.shape[2]):
                for j in range(s2.shape[3]):
                    so2[l_orient, lidx[0 : add + 2 - k], i, j] = so2[
                        l_orient, lidx2[0 : add + 2 - k], i, j
                    ]

    return so2
768
1147
 
769
def extrapol_s1(self, i_s1, add):
    """Extrapolate a 3-D first-order statistic (e.g. S1 or P00) by ``add`` scales.

    Every 1-D profile ``i_s1[k, :, i]`` is extended along axis 1 with
    ``model`` (dx=4, dell=1). NaNs produced by the fit are set to 0.

    Parameters
    ----------
    i_s1 : numpy.ndarray or tensor exposing ``.numpy()``
        3-D coefficients; axis 1 is the scale axis extended here.
    add : int
        Number of scales to add.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(d0, d1 + add, d2)``.
    """
    # Accept numpy arrays too (e.g. statistics loaded with ``read``).
    s1 = i_s1 if isinstance(i_s1, np.ndarray) else i_s1.numpy()
    so1 = np.zeros([s1.shape[0], s1.shape[1] + add, s1.shape[2]])
    for k in range(s1.shape[0]):
        for i in range(s1.shape[2]):
            so1[k, :, i] = self.model(s1[k, :, i], dx=4, dell=1, add=add)
            so1[k, np.isnan(so1[k, :, i]), i] = 0.0
    return so1
777
1156
 
778
def extrapol(self, add):
    """Return a new ``scat1D`` extrapolated by ``add`` scales.

    P00 and S1 are extended with ``extrapol_s1``, and S2 and S2L with
    ``extrapol_s2``. S0 is copied unchanged.

    The scale index arrays are rebuilt with ``findidx`` on the enlarged S2.
    ``extrapol_s2`` lays the output out in exactly that layout, and the old
    indices no longer match its size. The ``cross`` flag and backend are
    preserved.
    """
    s2_ext = self.extrapol_s2(add, lnorm=1)
    s2l_ext = self.extrapol_s2(add, lnorm=2)
    j1, j2 = self.findidx(s2_ext)
    return scat1D(
        self.extrapol_s1(self.P00, add),
        self.S0,
        self.extrapol_s1(self.S1, add),
        s2_ext,
        s2l_ext,
        j1,
        j2,
        cross=self.cross,
        backend=self.backend,
    )
1168
+
1169
+
789
1170
  class funct(FOC.FoCUS):
790
1171
 
791
def fill(self, im, nullval=0):
    """Delegate to ``self.fill_1d``, forwarding ``nullval``."""
    filler = self.fill_1d
    return filler(im, nullval=nullval)
1174
+
1175
def ud_grade(self, im, nout, axis=0):
    """Delegate to ``self.ud_grade_1d`` with the target size ``nout`` and ``axis``."""
    regrade = self.ud_grade_1d
    return regrade(im, nout, axis=axis)
1177
+
1178
def up_grade(self, im, nout, axis=0):
    """Delegate to ``self.up_grade_1d`` with the target size ``nout`` and ``axis``."""
    upgrade = self.up_grade_1d
    return upgrade(im, nout, axis=axis)
1180
+
1181
def smooth(self, data, axis=0):
    """Delegate to ``self.smooth_1d`` along ``axis``."""
    smoother = self.smooth_1d
    return smoother(data, axis=axis)
1183
+
1184
def convol(self, data, axis=0):
    """Delegate to ``self.convol_1d`` along ``axis``."""
    convolve = self.convol_1d
    return convolve(data, axis=axis)
1186
+
1187
def eval(
    self,
    image1,
    image2=None,
    mask=None,
    Auto=True,
    s0_off=1e-6,
    Add_R45=False,
    axis=0,
):
    """Compute the 1-D scattering statistics of ``image1``.

    When ``image2`` is given, cross statistics between ``image1`` and
    ``image2`` are computed, and the result is flagged ``cross=True``.

    Parameters
    ----------
    image1 : array-like
        Input timeline ``[Npix]`` or ``[Nbatch, Npix]``; the last axis is
        taken as the pixel axis.
    image2 : array-like, optional
        Second timeline for cross statistics, laid out like ``image1``.
    mask : array-like, optional
        ``[Nmask, Npix]`` weights. Defaults to a single all-ones mask.
    Auto : bool
        When True, products are reduced to their real part. When False in
        cross mode, the S0 of ``image2`` is concatenated to that of ``image1``.
    s0_off, Add_R45 :
        Accepted for backward compatibility; not used by this implementation.
    axis : int
        Offset added to the axis indices passed to the backend operations.

    Returns
    -------
    scat1D or None
        ``None`` (after printing a message) when the pixel count of ``mask``
        does not match ``image1``.
    """
    # Check input consistency
    if mask is not None:
        if len(image1.shape) == 1:
            if image1.shape[0] != mask.shape[1]:
                print(
                    "The mask should have the same size than the input timeline to eval Scattering"
                )
                return None
        else:
            if image1.shape[1] != mask.shape[1]:
                print(
                    "The mask should have the same size than the input timeline to eval Scattering"
                )
                return None

    ### AUTO OR CROSS
    cross = False
    if image2 is not None:
        cross = True

    # determine jmax and nside corresponding to the input map
    im_shape = image1.shape

    nside = im_shape[len(image1.shape) - 1]
    npix = nside

    jmax = int(np.log(nside) / np.log(2))  # -self.OSTEP

    ### LOCAL VARIABLES (IMAGES and MASK)
    # Promote a single timeline [Npix] to a batch of one [1, Npix].
    if len(image1.shape) == 1:
        I1 = self.backend.bk_cast(self.backend.bk_expand_dims(image1, 0))
        if cross:
            I2 = self.backend.bk_cast(self.backend.bk_expand_dims(image2, 0))
    else:
        I1 = self.backend.bk_cast(image1)
        if cross:
            I2 = self.backend.bk_cast(image2)

    # vmask is [Nmask, Npix]
    if mask is None:
        vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
    else:
        vmask = self.backend.bk_cast(mask)

    if self.KERNELSZ > 3:
        # if the kernel size is bigger than 3 increase the binning before smoothing
        l_image1 = self.up_grade_1d(I1, nside * 2, axis=axis + 1)
        vmask = self.up_grade_1d(vmask, nside * 2, axis=1)
        if cross:
            l_image2 = self.up_grade_1d(I2, nside * 2, axis=axis + 1)
    else:
        l_image1 = I1
        if cross:
            l_image2 = I2

    if len(image1.shape) == 1:
        s0 = self.backend.bk_reduce_sum(l_image1 * vmask, axis=axis + 1)
        if cross and not Auto:
            # Reduce over the pixel axis, as for image1 (was ``axis=axis``,
            # which summed over the mask axis instead).
            s0 = self.backend.bk_concat(
                [s0, self.backend.bk_reduce_sum(l_image2 * vmask, axis=axis + 1)]
            )
    else:
        lmask = self.backend.bk_expand_dims(vmask, 0)
        lim = self.backend.bk_expand_dims(l_image1, 1)
        s0 = self.backend.bk_reduce_sum(lim * lmask, axis=axis + 2)
        if cross and not Auto:
            lim = self.backend.bk_expand_dims(l_image2, 1)
            s0 = self.backend.bk_concat(
                [s0, self.backend.bk_reduce_sum(lim * lmask, axis=axis + 2)]
            )

    s1 = None
    s2 = None
    s2l = None
    p00 = None
    s2j1 = None
    s2j2 = None

    l2_image = None

    for j1 in range(jmax):
        if j1 < jmax - self.OSTEP:  # stop to add scales
            # Convolve along the pixel axis with the wavelet configured at
            # initialisation.
            c_image1 = self.convol_1d(l_image1, axis=axis + 1)
            if cross:
                c_image2 = self.convol_1d(l_image2, axis=axis + 1)
            else:
                c_image2 = c_image1

            # Product with the conjugate: |c|^2 in auto mode, c1 * conj(c2) in cross mode.
            conj = c_image1 * self.backend.bk_conjugate(c_image2)

            if Auto:
                conj = self.backend.bk_real(conj)

            # P00 term for this scale: masked sum of the product.
            l_p00 = self.backend.bk_expand_dims(
                self.backend.bk_reduce_sum(conj * vmask, axis=1), -1
            )

            conj = self.backend.bk_L1(conj)

            # S1 term for this scale: masked sum of the L1-normalised product.
            l_s1 = self.backend.bk_expand_dims(
                self.backend.bk_reduce_sum(conj * vmask, axis=1), -1
            )

            # Append this scale along the last axis.
            if s1 is None:
                s1 = l_s1
                p00 = l_p00
            else:
                s1 = self.backend.bk_concat([s1, l_s1], axis=-1)
                p00 = self.backend.bk_concat([p00, l_p00], axis=-1)

            # Stack the first-order modulus maps of all scales so far (newest first).
            if l2_image is None:
                l2_image = self.backend.bk_expand_dims(conj, axis=1)
            else:
                l2_image = self.backend.bk_concat(
                    [self.backend.bk_expand_dims(conj, axis=1), l2_image], axis=1
                )

            # Second convolution, applied separately to the positive and
            # negative parts of the stacked maps.
            c2_image = self.convol_1d(self.backend.bk_relu(l2_image), axis=axis + 2)

            conj2p = c2_image * self.backend.bk_conjugate(c2_image)
            conj2pl1 = self.backend.bk_L1(conj2p)

            if Auto:
                conj2p = self.backend.bk_real(conj2p)
                conj2pl1 = self.backend.bk_real(conj2pl1)

            c2_image = self.convol_1d(self.backend.bk_relu(-l2_image), axis=axis + 2)

            conj2m = c2_image * self.backend.bk_conjugate(c2_image)
            conj2ml1 = self.backend.bk_L1(conj2m)

            if Auto:
                conj2m = self.backend.bk_real(conj2m)
                conj2ml1 = self.backend.bk_real(conj2ml1)

            # Masked sums over the pixel axis.
            l_s2 = self.backend.bk_reduce_sum(
                (conj2p - conj2m) * self.backend.bk_expand_dims(vmask, 1), axis=axis + 2
            )
            l_s2l1 = self.backend.bk_reduce_sum(
                (conj2pl1 - conj2ml1) * self.backend.bk_expand_dims(vmask, 1),
                axis=axis + 2,
            )

            # NOTE: ``s2`` accumulates the L1-normalised term and ``s2l`` the
            # raw squared term; this naming is kept as-is.
            # The number of new S2 entries is the size of axis ``axis + 1``
            # (one per stacked scale). The first iteration used
            # ``shape[axis]``, which is the batch/mask size, and is wrong
            # whenever that is not 1.
            n_new = l_s2.shape[axis + 1]
            if s2 is None:
                s2l = l_s2
                s2 = l_s2l1
                s2j1 = np.arange(n_new, dtype="int")
                s2j2 = j1 * np.ones(n_new, dtype="int")
            else:
                s2 = self.backend.bk_concat([s2, l_s2l1], axis=-1)
                s2l = self.backend.bk_concat([s2l, l_s2], axis=-1)
                s2j1 = np.concatenate([s2j1, np.arange(n_new, dtype="int")], 0)
                s2j2 = np.concatenate([s2j2, j1 * np.ones(n_new, dtype="int")], 0)

        if j1 != jmax - 1:
            # Rescale vmask to half the pixels for the next scale.
            vmask = self.smooth_1d(vmask, axis=1)
            vmask = self.ud_grade_1d(vmask, vmask.shape[1] // 2, axis=1)
            if self.mask_thres is not None:
                vmask = self.backend.bk_threshold(vmask, self.mask_thres)

            # Rescale the stacked modulus maps along the pixel axis.
            l2_image = self.smooth_1d(l2_image, axis=axis + 2)
            l2_image = self.ud_grade_1d(
                l2_image, l2_image.shape[axis + 2] // 2, axis=axis + 2
            )

            # Rescale the input timeline(s) along the pixel axis.
            l_image1 = self.smooth_1d(l_image1, axis=axis + 1)
            l_image1 = self.ud_grade_1d(
                l_image1, l_image1.shape[axis + 1] // 2, axis=axis + 1
            )
            if cross:
                # Same pixel axis as l_image1 (was ``axis + 2``, an axis the
                # timeline does not have).
                l_image2 = self.smooth_1d(l_image2, axis=axis + 1)
                l_image2 = self.ud_grade_1d(
                    l_image2, l_image2.shape[axis + 1] // 2, axis=axis + 1
                )

    return scat1D(
        p00, s0, s1, s2, s2l, s2j1, s2j2, cross=cross, backend=self.backend
    )
968
1400
 
969
- def square(self,x):
1401
+ def square(self, x):
970
1402
  # the abs make the complex value usable for reduce_sum or mean
971
- return scat1D(self.backend.bk_square(self.backend.bk_abs(x.P00)),
972
- self.backend.bk_square(self.backend.bk_abs(x.S0)),
973
- self.backend.bk_square(self.backend.bk_abs(x.S1)),
974
- self.backend.bk_square(self.backend.bk_abs(x.S2)),
975
- self.backend.bk_square(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
976
-
977
- def sqrt(self,x):
1403
+ return scat1D(
1404
+ self.backend.bk_square(self.backend.bk_abs(x.P00)),
1405
+ self.backend.bk_square(self.backend.bk_abs(x.S0)),
1406
+ self.backend.bk_square(self.backend.bk_abs(x.S1)),
1407
+ self.backend.bk_square(self.backend.bk_abs(x.S2)),
1408
+ self.backend.bk_square(self.backend.bk_abs(x.S2L)),
1409
+ x.j1,
1410
+ x.j2,
1411
+ backend=self.backend,
1412
+ )
1413
+
1414
+ def sqrt(self, x):
978
1415
  # the abs make the complex value usable for reduce_sum or mean
979
- return scat1D(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
980
- self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
981
- self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
982
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
983
- self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),x.j1,x.j2,backend=self.backend)
984
-
985
- def reduce_mean(self,x,axis=None):
1416
+ return scat1D(
1417
+ self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1418
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S0)),
1419
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),
1420
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2)),
1421
+ self.backend.bk_sqrt(self.backend.bk_abs(x.S2L)),
1422
+ x.j1,
1423
+ x.j2,
1424
+ backend=self.backend,
1425
+ )
1426
+
1427
+ def reduce_mean(self, x, axis=None):
986
1428
  if axis is None:
987
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))+ \
988
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))+ \
989
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))+ \
990
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))+ \
991
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
992
-
993
- ntmp=np.array(list(x.P00.shape)).prod()+ \
994
- np.array(list(x.S0.shape)).prod()+ \
995
- np.array(list(x.S1.shape)).prod()+ \
996
- np.array(list(x.S2.shape)).prod()
997
-
998
- return tmp/ntmp
1429
+ tmp = (
1430
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00))
1431
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0))
1432
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1))
1433
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2))
1434
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L))
1435
+ )
1436
+
1437
+ ntmp = (
1438
+ np.array(list(x.P00.shape)).prod()
1439
+ + np.array(list(x.S0.shape)).prod()
1440
+ + np.array(list(x.S1.shape)).prod()
1441
+ + np.array(list(x.S2.shape)).prod()
1442
+ )
1443
+
1444
+ return tmp / ntmp
999
1445
  else:
1000
- tmp=self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00,axis=axis))+ \
1001
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0,axis=axis))+ \
1002
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1,axis=axis))+ \
1003
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2,axis=axis))+ \
1004
- self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L,axis=axis))
1005
-
1006
- ntmp=np.array(list(x.P00.shape)).prod()+ \
1007
- np.array(list(x.S0.shape)).prod()+ \
1008
- np.array(list(x.S1.shape)).prod()+ \
1009
- np.array(list(x.S2.shape)).prod()+ \
1010
- np.array(list(x.S2L.shape)).prod()
1011
-
1012
- return tmp/ntmp
1013
-
1014
- def reduce_sum(self,x,axis=None):
1446
+ tmp = (
1447
+ self.backend.bk_abs(self.backend.bk_reduce_sum(x.P00, axis=axis))
1448
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S0, axis=axis))
1449
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S1, axis=axis))
1450
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2, axis=axis))
1451
+ + self.backend.bk_abs(self.backend.bk_reduce_sum(x.S2L, axis=axis))
1452
+ )
1453
+
1454
+ ntmp = (
1455
+ np.array(list(x.P00.shape)).prod()
1456
+ + np.array(list(x.S0.shape)).prod()
1457
+ + np.array(list(x.S1.shape)).prod()
1458
+ + np.array(list(x.S2.shape)).prod()
1459
+ + np.array(list(x.S2L.shape)).prod()
1460
+ )
1461
+
1462
+ return tmp / ntmp
1463
+
1464
+ def reduce_sum(self, x, axis=None):
1015
1465
  if axis is None:
1016
- return self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))+ \
1017
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))+ \
1018
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))+ \
1019
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))+ \
1020
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1466
+ return (
1467
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.P00))
1468
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
1469
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
1470
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
1471
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2L))
1472
+ )
1021
1473
  else:
1022
- return scat1D(self.backend.bk_reduce_sum(x.P00,axis=axis),
1023
- self.backend.bk_reduce_sum(x.S0,axis=axis),
1024
- self.backend.bk_reduce_sum(x.S1,axis=axis),
1025
- self.backend.bk_reduce_sum(x.S2,axis=axis),
1026
- self.backend.bk_reduce_sum(x.S2L,axis=axis),x.j1,x.j2,backend=self.backend)
1027
-
1028
- def ldiff(self,sig,x):
1029
- return scat1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1030
- x.domult(sig.S0,x.S0)*x.domult(sig.S0,x.S0),
1031
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1032
- x.domult(sig.S2,x.S2)*x.domult(sig.S2,x.S2),
1033
- x.domult(sig.S2L,x.S2L)*x.domult(sig.S2L,x.S2L),x.j1,x.j2,backend=self.backend)
1034
-
1035
- def log(self,x):
1036
- return scat1D(self.backend.bk_log(x.P00),
1037
- self.backend.bk_log(x.S0),
1038
- self.backend.bk_log(x.S1),
1039
- self.backend.bk_log(x.S2),
1040
- self.backend.bk_log(x.S2L),x.j1,x.j2,backend=self.backend)
1041
- def abs(self,x):
1042
- return scat1D(self.backend.bk_abs(x.P00),
1043
- self.backend.bk_abs(x.S0),
1044
- self.backend.bk_abs(x.S1),
1045
- self.backend.bk_abs(x.S2),
1046
- self.backend.bk_abs(x.S2L),x.j1,x.j2,backend=self.backend)
1047
- def inv(self,x):
1048
- return scat1D(1/(x.P00),1/(x.S0),1/(x.S1),1/(x.S2),1/(x.S2L),x.j1,x.j2,backend=self.backend)
1474
+ return scat1D(
1475
+ self.backend.bk_reduce_sum(x.P00, axis=axis),
1476
+ self.backend.bk_reduce_sum(x.S0, axis=axis),
1477
+ self.backend.bk_reduce_sum(x.S1, axis=axis),
1478
+ self.backend.bk_reduce_sum(x.S2, axis=axis),
1479
+ self.backend.bk_reduce_sum(x.S2L, axis=axis),
1480
+ x.j1,
1481
+ x.j2,
1482
+ backend=self.backend,
1483
+ )
1484
+
1485
+ def ldiff(self, sig, x):
1486
+ return scat1D(
1487
+ x.domult(sig.P00, x.P00) * x.domult(sig.P00, x.P00),
1488
+ x.domult(sig.S0, x.S0) * x.domult(sig.S0, x.S0),
1489
+ x.domult(sig.S1, x.S1) * x.domult(sig.S1, x.S1),
1490
+ x.domult(sig.S2, x.S2) * x.domult(sig.S2, x.S2),
1491
+ x.domult(sig.S2L, x.S2L) * x.domult(sig.S2L, x.S2L),
1492
+ x.j1,
1493
+ x.j2,
1494
+ backend=self.backend,
1495
+ )
1496
+
1497
+ def log(self, x):
1498
+ return scat1D(
1499
+ self.backend.bk_log(x.P00),
1500
+ self.backend.bk_log(x.S0),
1501
+ self.backend.bk_log(x.S1),
1502
+ self.backend.bk_log(x.S2),
1503
+ self.backend.bk_log(x.S2L),
1504
+ x.j1,
1505
+ x.j2,
1506
+ backend=self.backend,
1507
+ )
1508
+
1509
+ def abs(self, x):
1510
+ return scat1D(
1511
+ self.backend.bk_abs(x.P00),
1512
+ self.backend.bk_abs(x.S0),
1513
+ self.backend.bk_abs(x.S1),
1514
+ self.backend.bk_abs(x.S2),
1515
+ self.backend.bk_abs(x.S2L),
1516
+ x.j1,
1517
+ x.j2,
1518
+ backend=self.backend,
1519
+ )
1520
+
1521
+ def inv(self, x):
1522
+ return scat1D(
1523
+ 1 / (x.P00),
1524
+ 1 / (x.S0),
1525
+ 1 / (x.S1),
1526
+ 1 / (x.S2),
1527
+ 1 / (x.S2L),
1528
+ x.j1,
1529
+ x.j2,
1530
+ backend=self.backend,
1531
+ )
1049
1532
 
1050
1533
  def one(self):
1051
- return scat1D(1.0,1.0,1.0,1.0,1.0,[0],[0],backend=self.backend)
1534
+ return scat1D(1.0, 1.0, 1.0, 1.0, 1.0, [0], [0], backend=self.backend)
1052
1535
 
1053
- @tf.function
1054
- def eval_comp_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1536
+ def eval_comp_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1055
1537
 
1056
- res=self.eval(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1057
- return res.P00,res.S0,res.S1,res.S2,res.S2L,res.j1,res.j2
1538
+ res = self.eval_fast(image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off)
1539
+ return res.P00, res.S0, res.S1, res.S2, res.S2L, res.j1, res.j2
1058
1540
 
1059
- def eval_fast(self, image1, image2=None,mask=None,Auto=True,s0_off=1E-6):
1060
- p0,s0,s1,s2,s2l,j1,j2=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,s0_off=s0_off)
1061
- return scat1D(p0,s0,s1,s2,s2l,j1,j2,backend=self.backend)
1062
-
1063
-
1064
-
1541
+ @tf_function
1542
+ def eval_fast(self, image1, image2=None, mask=None, Auto=True, s0_off=1e-6):
1543
+ p0, s0, s1, s2, s2l, j1, j2 = self.eval_comp_fast(
1544
+ image1, image2=image2, mask=mask, Auto=Auto, s0_off=s0_off
1545
+ )
1546
+ return scat1D(p0, s0, s1, s2, s2l, j1, j2, backend=self.backend)