foscat 3.0.8__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov1D.py CHANGED
@@ -1,1474 +1,18 @@
1
- import foscat.FoCUS as FOC
2
- import numpy as np
3
- import foscat.backend as bk
4
- import tensorflow as tf
5
- import pickle
6
- import matplotlib.pyplot as plt
1
+ import foscat.scat_cov as scat
7
2
 
8
- def read(filename):
9
- thescat = scat_cov1D(1, 1, 1)
10
- return thescat.read(filename)
11
-
12
- testwarn=0
13
3
 
14
4
  class scat_cov1D:
15
- def __init__(self, p00, c01, c11, s1=None, c10=None,backend=None):
16
- self.P00 = p00
17
- self.C01 = c01
18
- self.C11 = c11
19
- self.S1 = s1
20
- self.C10 = c10
21
- self.backend = backend
22
- self.idx1 = None
23
- self.idx2 = None
24
-
25
- def numpy(self):
26
- if self.S1 is None:
27
- s1 = None
28
- else:
29
- s1=self.S1.numpy()
30
- if self.C10 is None:
31
- c10 = None
32
- else:
33
- c10=self.C10.numpy()
34
-
35
- return scat_cov1D((self.P00.numpy()),
36
- (self.C01.numpy()),
37
- (self.C11.numpy()),
38
- s1=s1, c10=c10,backend=self.backend)
39
-
40
- def constant(self):
41
-
42
- if self.S1 is None:
43
- s1 = None
44
- else:
45
- s1=self.backend.constant(self.S1)
46
- if self.C10 is None:
47
- c10 = None
48
- else:
49
- c10=self.backend.constant(self.C10)
50
-
51
- return scat_cov1D(self.backend.constant(self.P00),
52
- self.backend.constant(self.C01),
53
- self.backend.constant(self.C11),
54
- s1=s1, c10=c10,backend=self.backend)
55
- def get_S1(self):
56
- return self.S1
57
-
58
- def get_P00(self):
59
- return self.P00
60
-
61
- def reset_P00(self):
62
- self.P00=0*self.P00
63
-
64
- def get_C01(self):
65
- return self.C01
66
-
67
- def get_C10(self):
68
- return self.C10
69
-
70
- def get_C11(self):
71
- return self.C11
72
-
73
- def get_j_idx(self):
74
- shape=list(self.P00.shape)
75
- if len(shape)==3:
76
- nscale=shape[2]
77
- else:
78
- nscale=shape[3]
79
-
80
- n=nscale*(nscale+1)//2
81
- j1=np.zeros([n],dtype='int')
82
- j2=np.zeros([n],dtype='int')
83
- n=0
84
- for i in range(nscale):
85
- for j in range(i+1):
86
- j1[n]=j
87
- j2[n]=i
88
- n=n+1
89
-
90
- return j1,j2
91
-
92
-
93
- def get_jc11_idx(self):
94
- shape=list(self.P00.shape)
95
- nscale=shape[2]
96
- n=nscale*(nscale-1)*(nscale-2)
97
- j1=np.zeros([n*2],dtype='int')
98
- j2=np.zeros([n*2],dtype='int')
99
- j3=np.zeros([n*2],dtype='int')
100
- n=0
101
- for i in range(nscale):
102
- for j in range(i+1):
103
- for k in range(j+1):
104
- j1[n]=k
105
- j2[n]=j
106
- j3[n]=i
107
- n=n+1
108
- return(j1[0:n],j2[0:n],j3[0:n])
109
-
110
- def __add__(self, other):
111
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
112
- isinstance(other, bool) or isinstance(other, scat_cov1D)
113
-
114
- if self.S1 is None:
115
- s1 = None
116
- else:
117
- if isinstance(other, scat_cov1D):
118
- if other.S1 is None:
119
- s1=None
120
- else:
121
- s1 = self.S1 + other.S1
122
- else:
123
- s1 = self.S1 + other
124
-
125
- if self.C10 is None:
126
- c10 = None
127
- else:
128
- if isinstance(other, scat_cov1D):
129
- if other.C10 is None:
130
- c10=None
131
- else:
132
- c10 = self.doadd(self.C10 , other.C10)
133
- else:
134
- c10 = self.C10 + other
135
-
136
- if self.C11 is None:
137
- c11 = None
138
- else:
139
- if isinstance(other, scat_cov1D):
140
- if other.C11 is None:
141
- c11 = None
142
- else:
143
- c11 = self.doadd(self.C11, other.C11 )
144
- else:
145
- c11 = self.C11+other
146
-
147
- if isinstance(other, scat_cov1D):
148
- return scat_cov1D(self.doadd(self.P00,other.P00),
149
- (self.C01 + other.C01),
150
- c11,s1=s1, c10=c10,backend=self.backend)
151
- else:
152
- return scat_cov1D((self.P00 + other),
153
- (self.C01 + other),
154
- c11,s1=s1, c10=c10,backend=self.backend)
155
-
156
-
157
- def relu(self):
158
-
159
- if self.S1 is None:
160
- s1 = None
161
- else:
162
- s1 = self.backend.bk_relu(self.S1)
163
-
164
- if self.C10 is None:
165
- c10 = None
166
- else:
167
- c10 = self.backend.bk_relu(self.c10)
168
-
169
- if self.C11 is None:
170
- c11 = None
171
- else:
172
- c11 = self.backend.bk_relu(self.c11)
173
-
174
- return scat_cov1D(self.backend.bk_relu(self.P00),
175
- self.backend.bk_relu(self.C01),
176
- c11,
177
- s1=s1,
178
- c10=c10,
179
- backend=self.backend)
180
-
181
- def __radd__(self, other):
182
- return self.__add__(other)
183
-
184
- def __truediv__(self, other):
185
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
186
- isinstance(other, bool) or isinstance(other, scat_cov1D)
187
-
188
- if self.S1 is None:
189
- s1 = None
190
- else:
191
- if isinstance(other, scat_cov1D):
192
- if other.S1 is None:
193
- s1 = None
194
- else:
195
- s1 = self.S1 / other.S1
196
- else:
197
- s1 = self.S1 / other
198
-
199
- if self.C10 is None:
200
- c10 = None
201
- else:
202
- if isinstance(other, scat_cov1D):
203
- if other.C10 is None:
204
- c10 = None
205
- else:
206
- c10 = self.dodiv(self.C10 , other.C10)
207
- else:
208
- c10 = self.C10 / other
209
-
210
- if self.C11 is None:
211
- c11 = None
212
- else:
213
- if isinstance(other, scat_cov1D):
214
- if other.C11 is None:
215
- c11 = None
216
- else:
217
- c11 = self.dodiv(self.C11, other.C11 )
218
- else:
219
- c11 = self.C11/other
220
-
221
- if isinstance(other, scat_cov1D):
222
- return scat_cov1D(self.dodiv(self.P00,other.P00),
223
- (self.C01 / other.C01),
224
- c11,s1=s1, c10=c10,backend=self.backend)
225
- else:
226
- return scat_cov1D((self.P00 / other),
227
- (self.C01 / other),
228
- c11,s1=s1, c10=c10,backend=self.backend)
229
-
230
- def __rtruediv__(self, other):
231
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
232
- isinstance(other, bool) or isinstance(other, scat_cov1D)
233
-
234
- if self.S1 is None:
235
- s1 = None
236
- else:
237
- if isinstance(other, scat_cov1D):
238
- s1 = other.S1 / self.S1
239
- else:
240
- s1 = other/self.S1
241
-
242
- if self.C10 is None:
243
- c10 = None
244
- else:
245
- if isinstance(other, scat_cov1D):
246
- c10 = self.dodiv(other.C10 , self.C10)
247
- else:
248
- c10 = other/self.C10
249
-
250
- if self.C11 is None:
251
- c11 = None
252
- else:
253
- if isinstance(other, scat_cov1D):
254
- if other.C11 is None:
255
- c11 = None
256
- else:
257
- c11 = self.dodiv( other.C11,self.C11 )
258
- else:
259
- c11 = other/self.C11
260
-
261
- if isinstance(other, scat_cov1D):
262
- return scat_cov1D(self.dodiv(other.P00,self.P00),
263
- (other.C01 / self.C01),
264
- c11,s1=s1, c10=c10,backend=self.backend)
265
- else:
266
- return scat_cov1D((other/self.P00 ),
267
- (other/self.C01 ),
268
- (other/self.C11 ),
269
- s1=s1, c10=c10,backend=self.backend)
270
-
271
- def __rsub__(self, other):
272
-
273
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
274
- isinstance(other, bool) or isinstance(other, scat_cov1D)
275
-
276
- if self.S1 is None:
277
- s1 = None
278
- else:
279
- if isinstance(other, scat_cov1D):
280
- if other.S1 is None:
281
- s1 = None
282
- else:
283
- s1 = other.S1 - self.S1
284
- else:
285
- s1 = other - self.S1
286
-
287
- if self.C10 is None:
288
- c10 = None
289
- else:
290
- if isinstance(other, scat_cov1D):
291
- if other.C10 is None:
292
- c10 = None
293
- else:
294
- c10 = self.domin(other.C10 , self.C10 )
295
- else:
296
- c10 = other - self.C10
297
-
298
- if self.C11 is None:
299
- c11 = None
300
- else:
301
- if isinstance(other, scat_cov1D):
302
- if other.C11 is None:
303
- c11 = None
304
- else:
305
- c11 = self.domin( other.C11,self.C11 )
306
- else:
307
- c11 = other - self.C11
308
-
309
- if isinstance(other, scat_cov1D):
310
- return scat_cov1D(self.domin(other.P00,self.P00),
311
- (other.C01 - self.C01),
312
- c11,s1=s1, c10=c10,
313
- backend=self.backend)
314
- else:
315
- return scat_cov1D((other-self.P00),
316
- (other-self.C01),
317
- c11,s1=s1, c10=c10,
318
- backend=self.backend)
319
-
320
- def __sub__(self, other):
321
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
322
- isinstance(other, bool) or isinstance(other, scat_cov1D)
323
-
324
- if self.S1 is None:
325
- s1 = None
326
- else:
327
- if isinstance(other, scat_cov1D):
328
- if other.S1 is None:
329
- s1 = None
330
- else:
331
- s1 = self.S1 - other.S1
332
- else:
333
- s1 = self.S1 - other
334
-
335
- if self.C10 is None:
336
- c10 = None
337
- else:
338
- if isinstance(other, scat_cov1D):
339
- if other.C10 is None:
340
- c10 = None
341
- else:
342
- c10 = self.domin(self.C10 , other.C10)
343
- else:
344
- c10 = self.C10 - other
345
-
346
- if self.C11 is None:
347
- c11 = None
348
- else:
349
- if isinstance(other, scat_cov1D):
350
- if other.C11 is None:
351
- c11 = None
352
- else:
353
- c11 = self.domin(self.C11 , other.C11)
354
- else:
355
- c11 = self.C11 - other
356
-
357
- if isinstance(other, scat_cov1D):
358
- return scat_cov1D(self.domin(self.P00,other.P00),
359
- (self.C01 - other.C01),
360
- c11,
361
- s1=s1, c10=c10,backend=self.backend)
362
- else:
363
- return scat_cov1D((self.P00 - other),
364
- (self.C01 - other),
365
- c11,
366
- s1=s1, c10=c10,backend=self.backend)
367
-
368
- def domult(self,x,y):
369
- if x.dtype==y.dtype:
370
- return x*y
371
- if x.dtype=='complex64' or x.dtype=='complex128':
372
-
373
- return self.backend.bk_complex(self.backend.bk_real(x)*y,self.backend.bk_imag(x)*y)
374
- else:
375
- return self.backend.bk_complex(self.backend.bk_real(y)*x,self.backend.bk_imag(y)*x)
376
-
377
- def dodiv(self,x,y):
378
- if x.dtype==y.dtype:
379
- return x/y
380
- if x.dtype=='complex64' or x.dtype=='complex128':
381
-
382
- return self.backend.bk_complex(self.backend.bk_real(x)/y,self.backend.bk_imag(x)/y)
383
- else:
384
- return self.backend.bk_complex(x/self.backend.bk_real(y),x/self.backend.bk_imag(y))
385
-
386
- def domin(self,x,y):
387
- if x.dtype==y.dtype:
388
- return x-y
389
- if x.dtype=='complex64' or x.dtype=='complex128':
390
-
391
- return self.backend.bk_complex(self.backend.bk_real(x)-y,self.backend.bk_imag(x)-y)
392
- else:
393
- return self.backend.bk_complex(x-self.backend.bk_real(y),x-self.backend.bk_imag(y))
394
-
395
- def doadd(self,x,y):
396
- if x.dtype==y.dtype:
397
- return x+y
398
- if x.dtype=='complex64' or x.dtype=='complex128':
399
-
400
- return self.backend.bk_complex(self.backend.bk_real(x)+y,self.backend.bk_imag(x)+y)
401
- else:
402
- return self.backend.bk_complex(x+self.backend.bk_real(y),x+self.backend.bk_imag(y))
403
-
404
-
405
- def __mul__(self, other):
406
- assert isinstance(other, float) or isinstance(other, np.float32) or isinstance(other, int) or \
407
- isinstance(other, bool) or isinstance(other, scat_cov1D)
408
-
409
- if self.S1 is None:
410
- s1 = None
411
- else:
412
- if isinstance(other, scat_cov1D):
413
- if other.S1 is None:
414
- s1 = None
415
- else:
416
- s1 = self.S1 * other.S1
417
- else:
418
- s1 = self.S1 * other
419
-
420
- if self.C10 is None:
421
- c10 = None
422
- else:
423
- if isinstance(other, scat_cov1D):
424
- if other.C10 is None:
425
- c10 = None
426
- else:
427
- c10 = self.domult(self.C10 , other.C10)
428
- else:
429
- c10 = self.C10 * other
430
-
431
- if self.C11 is None:
432
- c11 = None
433
- else:
434
- if isinstance(other, scat_cov1D):
435
- if other.C11 is None:
436
- c11 = None
437
- else:
438
- c11 = self.domult(self.C11 , other.C11)
439
- else:
440
- c11 = self.C11 * other
441
-
442
- if isinstance(other, scat_cov1D):
443
- return scat_cov1D(self.domult(self.P00,other.P00),
444
- self.domult(self.C01,other.C01),
445
- c11,
446
- s1=s1, c10=c10,backend=self.backend)
447
- else:
448
- return scat_cov1D((self.P00 * other),
449
- (self.C01 * other),
450
- c11,
451
- s1=s1, c10=c10,backend=self.backend)
452
-
453
-
454
- def __rmul__(self, other):
455
- return self.__mul__(other)
456
-
457
- # ---------------------------------------------−---------
458
- def interp(self,nscale,extend=True,constant=False):
459
-
460
- if nscale+2>self.P00.shape[2]:
461
- print('Can not *interp* %d with a statistic described over %d'%(nscale,self.P00.shape[2]))
462
- return scat_cov1D(self.P00,self.C01,self.C11,s1=self.S1,c10=self.C10,backend=self.backend)
463
-
464
- if self.S1 is not None:
465
- s1=self.S1.numpy()
466
- else:
467
- s1=self.S1
468
-
469
- p0=self.P00.numpy()
470
- for k in range(nscale):
471
- if constant:
472
- if self.S1 is not None:
473
- s1[:,:,nscale-1-k,:]=s1[:,:,nscale-k,:]
474
- p0[:,:,nscale-1-k,:]=p0[:,:,nscale-k,:]
475
- else:
476
- if self.S1 is not None:
477
- s1[:,:,nscale-1-k,:]=np.exp(2*np.log(s1[:,:,nscale-k,:])-np.log(s1[:,:,nscale+1-k,:]))
478
- p0[:,:,nscale-1-k,:]=np.exp(2*np.log(p0[:,:,nscale-k,:])-np.log(p0[:,:,nscale+1-k,:]))
479
-
480
- j1,j2=self.get_j_idx()
481
-
482
- if self.C10 is not None:
483
- c10=self.C10.numpy()
484
- else:
485
- c10=self.C10
486
- c01=self.C01.numpy()
487
-
488
- for k in range(nscale):
489
-
490
- for l in range(nscale-k):
491
- i0=np.where((j1==nscale-1-k-l)*(j2==nscale-1-k))[0]
492
- i1=np.where((j1==nscale-1-k-l)*(j2==nscale -k))[0]
493
- i2=np.where((j1==nscale-1-k-l)*(j2==nscale+1-k))[0]
494
- if constant:
495
- c10[:,:,i0]=c10[:,:,i1]
496
- c01[:,:,i0]=c01[:,:,i1]
497
- else:
498
- c10[:,:,i0]=np.exp(2*np.log(c10[:,:,i1])-np.log(c10[:,:,i2]))
499
- c01[:,:,i0]=np.exp(2*np.log(c01[:,:,i1])-np.log(c01[:,:,i2]))
500
-
501
-
502
- c11=self.C11.numpy()
503
- j1,j2,j3=self.get_jc11_idx()
504
-
505
- for k in range(nscale):
506
-
507
- for l in range(nscale-k):
508
- for m in range(nscale-k-l):
509
- i0=np.where((j1==nscale-1-k-l-m)*(j2==nscale-1-k-l)*(j3==nscale-1-k))[0]
510
- i1=np.where((j1==nscale-1-k-l-m)*(j2==nscale-1-k-l)*(j3==nscale -k))[0]
511
- i2=np.where((j1==nscale-1-k-l-m)*(j2==nscale-1-k-l)*(j3==nscale+1-k))[0]
512
- if constant:
513
- c11[:,:,i0]=c11[:,:,i1]
514
- else:
515
- c11[:,:,i0]=np.exp(2*np.log(c11[:,:,i1])-np.log(c11[:,:,i2]))
516
-
517
- if s1 is not None:
518
- s1=self.backend.constant(s1)
519
- if c10 is not None:
520
- c10=self.backend.constant(c10)
521
-
522
- return scat_cov1D(self.backend.constant(p0),self.backend.constant(c01),
523
- self.backend.constant(c11),s1=s1,c10=c10,backend=self.backend)
524
-
525
- def plot(self, name=None, hold=True, color='blue', lw=1, legend=True):
526
-
527
- import matplotlib.pyplot as plt
528
-
529
- if name is None:
530
- name = ''
531
-
532
- j1,j2=self.get_j_idx()
533
-
534
- if hold:
535
- plt.figure(figsize=(16, 8))
536
-
537
- if self.S1 is not None:
538
- plt.subplot(2, 2, 1)
539
- tmp=abs(self.get_np(self.S1))
540
- test=None
541
- for i1 in range(tmp.shape[0]):
542
- for i2 in range(tmp.shape[1]):
543
- if test is None:
544
- test=1
545
- plt.plot(tmp[i1,i2,:],color=color, label=r'%s $S_1$' % (name), lw=lw)
546
- else:
547
- plt.plot(tmp[i1,i2,:],color=color, lw=lw)
548
- plt.yscale('log')
549
- plt.legend()
550
- plt.ylabel('S1')
551
- plt.xlabel(r'$j_{1}$')
552
-
553
- test=None
554
- plt.subplot(2, 2, 2)
555
- tmp=abs(self.get_np(self.P00))
556
- for i1 in range(tmp.shape[0]):
557
- for i2 in range(tmp.shape[0]):
558
- if test is None:
559
- test=1
560
- plt.plot(tmp[i1,i2,:],color=color, label=r'%s $P_{00}$' % (name), lw=lw)
561
- else:
562
- plt.plot(tmp[i1,i2,:],color=color, lw=lw)
563
- plt.yscale('log')
564
- plt.ylabel('P00')
565
- plt.xlabel(r'$j_{1}$')
566
- plt.legend()
567
-
568
- ax1=plt.subplot(2, 2, 3)
569
- ax2 = ax1.twiny()
570
- n=0
571
- tmp=abs(self.get_np(self.C01))
572
- lname=r'%s $C_{01}$' % (name)
573
- ax1.set_ylabel(r'$C_{01}$')
574
- if self.C10 is not None:
575
- tmp=abs(self.get_np(self.C01))
576
- lname=r'%s $C_{10}$' % (name)
577
- ax1.set_ylabel(r'$C_{10}$')
578
- test=None
579
- tabx=[]
580
- tabnx=[]
581
- tab2x=[]
582
- tab2nx=[]
583
-
584
- for i0 in range(tmp.shape[0]):
585
- for i1 in range(tmp.shape[1]):
586
- for i2 in range(j1.max()+1):
587
- if j2[j1==i2].shape[0]==1:
588
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2],'.', \
589
- color=color, lw=lw)
590
- else:
591
- if legend and test is None:
592
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
593
- color=color, label=lname, lw=lw)
594
- test=1
595
- ax1.plot(j2[j1==i2]+n,tmp[i0,i1,j1==i2], \
596
- color=color, lw=lw)
597
- tabnx=tabnx+[r'%d'%(k) for k in j2[j1==i2]]
598
- tabx=tabx+[k+n for k in j2[j1==i2]]
599
- tab2x=tab2x+[(j2[j1==i2]+n).mean()]
600
- tab2nx=tab2nx+['%d'%(i2)]
601
- ax1.axvline((j2[j1==i2]+n).max()+0.5,ls=':',color='gray')
602
- n=n+j2[j1==i2].shape[0]-1
603
- plt.yscale('log')
604
- ax1.set_xlim(0,n+2)
605
- ax1.set_xticks(tabx)
606
- ax1.set_xticklabels(tabnx,fontsize=6)
607
- ax1.set_xlabel(r"$j_{2}$",fontsize=6)
608
-
609
- # Move twinned axis ticks and label from top to bottom
610
- ax2.xaxis.set_ticks_position("bottom")
611
- ax2.xaxis.set_label_position("bottom")
612
-
613
- # Offset the twin axis below the host
614
- ax2.spines["bottom"].set_position(("axes", -0.15))
615
-
616
- # Turn on the frame for the twin axis, but then hide all
617
- # but the bottom spine
618
- ax2.set_frame_on(True)
619
- ax2.patch.set_visible(False)
620
-
621
- for sp in ax2.spines.values():
622
- sp.set_visible(False)
623
- ax2.spines["bottom"].set_visible(True)
624
- ax2.set_xlim(0,n+2)
625
- ax2.set_xticks(tab2x)
626
- ax2.set_xticklabels(tab2nx,fontsize=6)
627
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
628
- ax1.legend(frameon=0)
629
-
630
- ax1=plt.subplot(2, 2, 4)
631
- j1,j2,j3=self.get_jc11_idx()
632
- ax2 = ax1.twiny()
633
- n=1
634
- tmp=abs(self.get_np(self.C11))
635
- lname=r'%s $C_{11}$' % (name)
636
- test=None
637
- tabx=[]
638
- tabnx=[]
639
- tab2x=[]
640
- tab2nx=[]
641
- for i0 in range(tmp.shape[0]):
642
- for i1 in range(tmp.shape[1]):
643
- for i2 in range(j1.max()+1):
644
- nprev=n
645
- for i2b in range(j2[j1==i2].max()+1):
646
- idx=np.where((j1==i2)*(j2==i2b))[0]
647
- if len(idx)==1:
648
- ax1.plot(np.arange(len(idx))+n,tmp[i0,i1,idx],'.', \
649
- color=color, lw=lw)
650
- else:
651
- if legend and test is None:
652
- ax1.plot(np.arange(len(idx))+n,tmp[i0,i1,idx], \
653
- color=color, label=lname, lw=lw)
654
- test=1
655
- ax1.plot(np.arange(len(idx))+n,tmp[i0,i1,idx], \
656
- color=color, lw=lw)
657
- tabnx=tabnx+[r'%d,%d'%(j2[k],j3[k]) for k in idx]
658
- tabx=tabx+[k+n for k in range(len(idx))]
659
- n=n+idx.shape[0]
660
- tab2x=tab2x+[(n+nprev-1)/2]
661
- tab2nx=tab2nx+['%d'%(i2)]
662
- ax1.axvline(n-0.5,ls=':',color='gray')
663
- plt.yscale('log')
664
- ax1.set_ylabel(r'$C_{11}$')
665
- ax1.set_xticks(tabx)
666
- ax1.set_xticklabels(tabnx,fontsize=6)
667
- ax1.set_xlabel(r"$j_{2},j_{3}$",fontsize=6)
668
- ax1.set_xlim(0,n)
669
-
670
- # Move twinned axis ticks and label from top to bottom
671
- ax2.xaxis.set_ticks_position("bottom")
672
- ax2.xaxis.set_label_position("bottom")
673
-
674
- # Offset the twin axis below the host
675
- ax2.spines["bottom"].set_position(("axes", -0.15))
676
-
677
- # Turn on the frame for the twin axis, but then hide all
678
- # but the bottom spine
679
- ax2.set_frame_on(True)
680
- ax2.patch.set_visible(False)
681
-
682
- for sp in ax2.spines.values():
683
- sp.set_visible(False)
684
- ax2.spines["bottom"].set_visible(True)
685
- ax2.set_xlim(0,n)
686
- ax2.set_xticks(tab2x)
687
- ax2.set_xticklabels(tab2nx,fontsize=6)
688
- ax2.set_xlabel(r"$j_{1}$",fontsize=6)
689
- ax1.legend(frameon=0)
690
-
691
- def get_np(self, x):
692
- if x is not None:
693
- if isinstance(x, np.ndarray):
694
- return x
695
- else:
696
- return x.numpy()
697
- else:
698
- return None
699
-
700
- def save(self, filename):
701
-
702
- outlist=[self.get_np(self.S1), \
703
- self.get_np(self.C10), \
704
- self.get_np(self.C01), \
705
- self.get_np(self.C11), \
706
- self.get_np(self.P00)]
707
-
708
- myout=open("%s.pkl"%(filename),"wb")
709
- pickle.dump(outlist,myout)
710
- myout.close()
711
-
712
- def read(self, filename):
713
-
714
- outlist=pickle.load(open("%s.pkl"%(filename),"rb"))
715
-
716
- return scat_cov1D(outlist[4], outlist[2], outlist[3], \
717
- s1=outlist[0], c10=outlist[1],backend=self.backend)
718
-
719
- def std(self):
720
- if self.S1 is not None: # Auto
721
- return np.sqrt(((abs(self.get_np(self.S1)).std()) ** 2 +
722
- (abs(self.get_np(self.C01)).std()) ** 2 +
723
- (abs(self.get_np(self.C11)).std()) ** 2 +
724
- (abs(self.get_np(self.P00)).std()) ** 2 ) / 4)
725
- else: # Cross
726
- return np.sqrt(((abs(self.get_np(self.C01)).std()) ** 2 +
727
- (abs(self.get_np(self.C10)).std()) ** 2 +
728
- (abs(self.get_np(self.C11)).std()) ** 2 +
729
- (abs(self.get_np(self.P00)).std()) ** 2) / 4)
730
-
731
- def mean(self):
732
- if self.S1 is not None: # Auto
733
- return (abs(self.get_np(self.S1)).mean() +
734
- abs(self.get_np(self.C01)).mean() +
735
- abs(self.get_np(self.C11)).mean() +
736
- abs(self.get_np(self.P00)).mean()) / 4
737
- else: # Cross
738
- return (abs(self.get_np(self.C01)).mean() +
739
- abs(self.get_np(self.C10)).mean() +
740
- abs(self.get_np(self.C11)).mean() +
741
- abs(self.get_np(self.P00)).mean()) / 4
742
-
743
- def initdx(self,norient):
744
- idx1=np.zeros([norient*norient],dtype='int')
745
- for i in range(norient):
746
- idx1[i*norient:(i+1)*norient]=(np.arange(norient)+i)%norient+i*norient
747
-
748
- idx2=np.zeros([norient*norient*norient],dtype='int')
749
- for i in range(norient):
750
- for j in range(norient):
751
- idx2[i*norient*norient+j*norient:i*norient*norient+(j+1)*norient]= \
752
- ((np.arange(norient)+i)%norient)*norient \
753
- +(np.arange(norient)+i+j)%norient+np.arange(norient)*norient*norient
754
- self.idx1=self.backend.constant(idx1)
755
- self.idx2=self.backend.constant(idx2)
756
-
757
- def get_nscale(self):
758
- return self.P00.shape[2]
759
-
760
- def get_norient(self):
761
- return self.P00.shape[3]
762
-
763
- def add_data_from_log_slope(self,y,n,ds=3):
764
- if len(y)<ds:
765
- if len(y)==1:
766
- return(np.repeat(y[0],n))
767
- if len(y)==2:
768
- a=np.polyfit(np.arange(2),np.log(y[0:2]),1)
769
- else:
770
- a=np.polyfit(np.arange(ds),np.log(y[0:ds]),1)
771
- return np.exp((np.arange(n)-1-n)*a[0]+a[1])
772
-
773
- def add_data_from_slope(self,y,n,ds=3):
774
- if len(y)<ds:
775
- if len(y)==1:
776
- return(np.repeat(y[0],n))
777
- if len(y)==2:
778
- a=np.polyfit(np.arange(2),y[0:2],1)
779
- else:
780
- a=np.polyfit(np.arange(ds),y[0:ds],1)
781
- return (np.arange(n)-1-n)*a[0]+a[1]
782
-
783
- def up_grade(self,nscale,ds=3):
784
- noff=nscale-self.P00.shape[2]
785
- if noff==0:
786
- return scat_cov1D((self.P00),
787
- (self.C01),
788
- (self.C11),
789
- s1=self.S1,
790
- c10=self.C10,backend=self.backend)
791
-
792
- inscale=self.P00.shape[2]
793
- p00=np.zeros([self.P00.shape[0],self.P00.shape[1],nscale,self.P00.shape[3]],dtype='complex')
794
- p00[:,:,noff:,:]=self.P00.numpy()
795
- for i in range(self.P00.shape[0]):
796
- for j in range(self.P00.shape[1]):
797
- for k in range(self.P00.shape[3]):
798
- p00[i,j,0:noff,k]=self.add_data_from_log_slope(p00[i,j,noff:,k],noff,ds=ds)
799
-
800
- s1=np.zeros([self.S1.shape[0],self.S1.shape[1],nscale,self.S1.shape[3]])
801
- s1[:,:,noff:,:]=self.S1.numpy()
802
- for i in range(self.S1.shape[0]):
803
- for j in range(self.S1.shape[1]):
804
- for k in range(self.S1.shape[3]):
805
- s1[i,j,0:noff,k]=self.add_data_from_log_slope(s1[i,j,noff:,k],noff,ds=ds)
806
-
807
- nout=0
808
- for i in range(1,nscale):
809
- nout=nout+i
810
-
811
- c01=np.zeros([self.C01.shape[0],self.C01.shape[1], \
812
- nout,self.C01.shape[3],self.C01.shape[4]],dtype='complex')
813
-
814
- jo1=np.zeros([nout])
815
- jo2=np.zeros([nout])
816
-
817
- n=0
818
- for i in range(1,nscale):
819
- jo1[n:n+i]=np.arange(i)
820
- jo2[n:n+i]=i
821
- n=n+i
822
-
823
- j1=np.zeros([self.C01.shape[2]])
824
- j2=np.zeros([self.C01.shape[2]])
825
-
826
- n=0
827
- for i in range(1,self.P00.shape[2]):
828
- j1[n:n+i]=np.arange(i)
829
- j2[n:n+i]=i
830
- n=n+i
831
-
832
- for i in range(self.C01.shape[0]):
833
- for j in range(self.C01.shape[1]):
834
- for k in range(self.C01.shape[3]):
835
- for l in range(self.C01.shape[4]):
836
- for ij in range(noff+1,nscale):
837
- idx=np.where(jo2==ij)[0]
838
- c01[i,j,idx[noff:],k,l]=self.C01.numpy()[i,j,j2==ij-noff,k,l]
839
- c01[i,j,idx[:noff],k,l]=self.add_data_from_slope(self.C01.numpy()[i,j,j2==ij-noff,k,l],noff,ds=ds)
840
-
841
- for ij in range(nscale):
842
- idx=np.where(jo1==ij)[0]
843
- if idx.shape[0]>noff:
844
- c01[i,j,idx[:noff],k,l]=self.add_data_from_slope(c01[i,j,idx[noff:],k,l],noff,ds=ds)
845
- else:
846
- c01[i,j,idx,k,l]=np.mean(c01[i,j,jo1==ij-1,k,l])
847
-
848
-
849
- nout=0
850
- for j3 in range(nscale):
851
- for j2 in range(0,j3):
852
- for j1 in range(0,j2):
853
- nout=nout+1
854
-
855
- c11=np.zeros([self.C11.shape[0],self.C11.shape[1], \
856
- nout,self.C11.shape[3], \
857
- self.C11.shape[4],self.C11.shape[5]],dtype='complex')
858
-
859
- jo1=np.zeros([nout])
860
- jo2=np.zeros([nout])
861
- jo3=np.zeros([nout])
862
-
863
- nout=0
864
- for j3 in range(nscale):
865
- for j2 in range(0,j3):
866
- for j1 in range(0,j2):
867
- jo1[nout]=j1
868
- jo2[nout]=j2
869
- jo3[nout]=j3
870
- nout=nout+1
871
-
872
- ncross=self.C11.shape[2]
873
- jj1=np.zeros([ncross])
874
- jj2=np.zeros([ncross])
875
- jj3=np.zeros([ncross])
876
-
877
- n=0
878
- for j3 in range(inscale):
879
- for j2 in range(0,j3):
880
- for j1 in range(0,j2):
881
- jj1[n]=j1
882
- jj2[n]=j2
883
- jj3[n]=j3
884
- n=n+1
885
-
886
- n=0
887
- for j3 in range(nscale):
888
- for j2 in range(j3):
889
- idx=np.where((jj3==j3)*(jj2==j2))[0]
890
- if idx.shape[0]>0:
891
- idx2=np.where((jo3==j3+noff)*(jo2==j2+noff))[0]
892
- for i in range(self.C11.shape[0]):
893
- for j in range(self.C11.shape[1]):
894
- for k in range(self.C11.shape[3]):
895
- for l in range(self.C11.shape[4]):
896
- for m in range(self.C11.shape[5]):
897
- c11[i,j,idx2[noff:],k,l,m]=self.C11.numpy()[i,j,idx,k,l,m]
898
- c11[i,j,idx2[:noff],k,l,m]=self.add_data_from_log_slope(self.C11.numpy()[i,j,idx,k,l,m],noff,ds=ds)
899
-
900
- idx=np.where(abs(c11[0,0,:,0,0,0])==0)[0]
901
- for iii in idx:
902
- iii1=np.where((jo1==jo1[iii]+1)*(jo2==jo2[iii]+1)*(jo3==jo3[iii]+1))[0]
903
- iii2=np.where((jo1==jo1[iii]+2)*(jo2==jo2[iii]+2)*(jo3==jo3[iii]+2))[0]
904
- if iii2.shape[0]>0:
905
- for i in range(self.C11.shape[0]):
906
- for j in range(self.C11.shape[1]):
907
- for k in range(self.C11.shape[3]):
908
- for l in range(self.C11.shape[4]):
909
- for m in range(self.C11.shape[5]):
910
- c11[i,j,iii,k,l,m]=self.add_data_from_slope(c11[i,j,[iii1,iii2],k,l,m],1,ds=2)[0]
911
-
912
- idx=np.where(abs(c11[0,0,:,0,0,0])==0)[0]
913
- for iii in idx:
914
- iii1=np.where((jo1==jo1[iii])*(jo2==jo2[iii])*(jo3==jo3[iii]-1))[0]
915
- iii2=np.where((jo1==jo1[iii])*(jo2==jo2[iii])*(jo3==jo3[iii]-2))[0]
916
- if iii2.shape[0]>0:
917
- for i in range(self.C11.shape[0]):
918
- for j in range(self.C11.shape[1]):
919
- for k in range(self.C11.shape[3]):
920
- for l in range(self.C11.shape[4]):
921
- for m in range(self.C11.shape[5]):
922
- c11[i,j,iii,k,l,m]=self.add_data_from_slope(c11[i,j,[iii1,iii2],k,l,m],1,ds=2)[0]
923
-
924
- return scat_cov1D( (p00),
925
- (c01),
926
- (c11),
927
- s1=(s1),backend=self.backend)
928
-
929
-
930
-
931
- class funct(FOC.FoCUS):
932
-
933
- def fill(self,im,nullval=0):
934
- return self.fill_1d(im,nullval=nullval)
935
-
936
- def ud_grade(self,im,nout,axis=0):
937
- return self.ud_grade_1d(im,nout,axis=axis)
938
-
939
- def up_grade(self,im,nout,axis=0):
940
- return self.up_grade_1d(im,nout,axis=axis)
941
-
942
- def smooth(self,data,axis=0):
943
- return self.smooth_1d(data,axis=axis)
944
-
945
- def convol(self,data,axis=0):
946
- return self.convol_1d(data,axis=axis)
947
-
948
- def eval(self, image1, image2=None, mask=None, Auto=True):
949
- """
950
- Calculates the scattering correlations for a batch of images. Mean are done over pixels.
951
- mean of modulus:
952
- S1 = <|I * Psi_j3|>
953
- Normalization : take the log
954
- power spectrum:
955
- P00 = <|I * Psi_j3|^2>
956
- Normalization : take the log
957
- orig. x modulus:
958
- C01 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
959
- Normalization : divide by (P00_j2 * P00_j3)^0.5
960
- modulus x modulus:
961
- C11 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
962
- Normalization : divide by (P00_j1 * P00_j2)^0.5
963
- Parameters
964
- ----------
965
- image1: tensor
966
- Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
967
- image2: tensor
968
- Second image. If not None, we compute cross-scattering covariance coefficients.
969
- mask:
970
- norm: None or str
971
- If None no normalization is applied, if 'auto' normalize by the reference P00,
972
- if 'self' normalize by the current P00.
973
- all_cross: False or True
974
- If False compute all the coefficient even the Imaginary part,
975
- If True return only the terms computable in the auto case.
976
- Returns
977
- -------
978
- S1, P00, C01, C11 normalized
979
- """
980
-
981
- norm=None
982
- axis=0
983
-
984
- # Check input consistency
985
- if image2 is not None:
986
- if list(image1.shape)!=list(image2.shape):
987
- print('The two input image should have the same size to eval Scattering Covariance')
988
- exit(0)
989
- if mask is not None:
990
- if list(image1.shape)!=list(mask.shape)[1:]:
991
- print('The mask should have the same size ',mask.shape,'than the input image ',image1.shape,'to eval Scattering Covariance')
992
- exit(0)
993
-
994
- ### AUTO OR CROSS
995
- cross = False
996
- if image2 is not None:
997
- cross = True
998
- all_cross=Auto
999
- else:
1000
- all_cross=False
1001
-
1002
- ### PARAMETERS
1003
- # determine jmax and nside corresponding to the input map
1004
- im_shape = image1.shape
1005
-
1006
- npix=im_shape[axis]
1007
-
1008
- J = int(np.log(npix) / np.log(2)) # Number of j scales
1009
- Jmax = J - self.OSTEP # Number of steps for the loop on scales
1010
-
1011
- ### LOCAL VARIABLES (IMAGES and MASK)
1012
- if len(image1.shape) == 1:
1013
- I1 = self.backend.bk_cast(self.backend.bk_expand_dims(image1, 0)) # Local image1 [Nbatch, Npix]
1014
- if cross:
1015
- I2 = self.backend.bk_cast(self.backend.bk_expand_dims(image2, 0)) # Local image2 [Nbatch, Npix]
1016
- else:
1017
- I1 = self.backend.bk_cast(image1) # Local image1 [Nbatch, Npix]
1018
- if cross:
1019
- I2 = self.backend.bk_cast(image2) # Local image2 [Nbatch, Npix]
1020
-
1021
- if mask is None:
1022
- vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
1023
- else:
1024
- vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
1025
-
1026
- if self.KERNELSZ > 3:
1027
- # if the kernel size is bigger than 3 increase the binning before smoothing
1028
- I1 = self.up_grade_1d(I1, npix * 2, axis=axis+1)
1029
- vmask = self.up_grade_1d(vmask, npix * 2, axis=axis+1)
1030
- if cross:
1031
- I2 = self.up_grade_1d(I2, npix * 2, axis=axis+1)
1032
-
1033
- # Normalize the masks because they have different pixel numbers
1034
- # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
1035
-
1036
- ### INITIALIZATION
1037
- # Coefficients
1038
- S1, P00, C01, C11, C10 = None, None, None, None, None
1039
-
1040
- # Dictionaries for C01 computation
1041
- M1_dic = {} # M stands for Module M1 = |I1 * Psi|
1042
- if cross:
1043
- M2_dic = {}
1044
-
1045
- # P00 for normalization
1046
- cond_init_P1_dic = ((norm == 'self') or ((norm == 'auto') and (self.P1_dic is None)))
1047
- if norm is None:
1048
- pass
1049
- elif cond_init_P1_dic:
1050
- P1_dic = {}
1051
- if cross:
1052
- P2_dic = {}
1053
- elif (norm == 'auto') and (self.P1_dic is not None):
1054
- P1_dic = self.P1_dic
1055
- if cross:
1056
- P2_dic = self.P2_dic
1057
-
1058
-
1059
- #### COMPUTE S1, P00, C01 and C11
1060
- nside_j3 = npix # NSIDE start (nside_j3 = nside / 2^j3)
1061
-
1062
- for j3 in range(Jmax):
1063
-
1064
- ####### S1 and P00
1065
- ### Make the convolution I1 * Psi_j3
1066
- conv1 = self.convol_1d(I1, axis=axis+1) # [Nbatch, Npix_j3]
1067
- ### Take the module M1 = |I1 * Psi_j3|
1068
- M1_square = conv1*self.backend.bk_conjugate(conv1) # [Nbatch, Npix_j3]
1069
- M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3]
1070
- # Store M1_j3 in a dictionary
1071
- M1_dic[j3] = M1
1072
-
1073
- if not cross: # Auto
1074
- M1_square=self.backend.bk_real(M1_square)
1075
-
1076
- ### P00_auto = < M1^2 >_pix
1077
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1078
- if M1_square.dtype=='complex64' or M1_square.dtype=='complex128':
1079
- p00 = self.backend.bk_reduce_sum(M1_square[:, None, :]*self.backend.bk_complex(vmask[None,:, :],0*vmask[None,:, :]), axis=2)
1080
- else:
1081
- p00 = self.backend.bk_reduce_sum(M1_square[:, None, :]*vmask[None,:, :], axis=2)
1082
-
1083
- if cond_init_P1_dic:
1084
- # We fill P1_dic with P00 for normalisation of C01 and C11
1085
- P1_dic[j3] = p00 # [Nbatch, Nmask, Norient3]
1086
- if norm == 'auto': # Normalize P00
1087
- p00 /= P1_dic[j3]
1088
-
1089
- # We store P00_auto to return it [Nbatch, Nmask, NP00]
1090
- if P00 is None:
1091
- P00 = p00[:, :, None] # Add a dimension for NP00
1092
- else:
1093
- P00 = self.backend.bk_concat([P00, p00[:, :, None]], axis=axis+2)
1094
-
1095
- #### S1_auto computation
1096
- ### Image 1 : S1 = < M1 >_pix
1097
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1098
- if M1.dtype=='complex64' or M1.dtype=='complex128':
1099
- s1 = self.backend.bk_reduce_sum(M1[:, None, :]*self.backend.bk_complex(vmask[None,:, :],0*vmask[None,:, :]), axis=2)
1100
- else:
1101
- s1 = self.backend.bk_reduce_sum(M1[:, None, :]*vmask[None,:, :], axis=2)
1102
-
1103
- ### Normalize S1
1104
- if norm is not None:
1105
- s1 /= (P1_dic[j3]) ** 0.5
1106
- ### We store S1 for image1 [Nbatch, Nmask, NS1]
1107
- if S1 is None:
1108
- S1 = s1[:, :, None] # Add a dimension for NS1
1109
- else:
1110
- S1 = self.backend.bk_concat([S1, s1[:, :, None]], axis=axis+2)
1111
-
1112
- else: # Cross
1113
- ### Make the convolution I2 * Psi_j3
1114
- conv2 = self.convol_1d(I2, axis=axis+1) # [Nbatch, Npix_j3, Norient3]
1115
- ### Take the module M2 = |I2 * Psi_j3|
1116
- M2_square = conv2*self.backend.bk_conjugate(conv2) # [Nbatch, Npix_j3, Norient3]
1117
- M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
1118
- # Store M2_j3 in a dictionary
1119
- M2_dic[j3] = M2
1120
-
1121
- ### P00_auto = < M2^2 >_pix
1122
- # Not returned, only for normalization
1123
- if cond_init_P1_dic:
1124
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1125
- p1 = self.backend.bk_reduce_sum(M1_square* vmask, axis=axis+1) # [Nbatch, Nmask, Norient3]
1126
- p2 = self.backend.bk_reduce_sum(M2_square* vmask, axis=axis+1) # [Nbatch, Nmask, Norient3]
1127
- # We fill P1_dic with P00 for normalisation of C01 and C11
1128
- P1_dic[j3] = p1 # [Nbatch, Nmask, Norient3]
1129
- P2_dic[j3] = p2 # [Nbatch, Nmask, Norient3]
1130
-
1131
- ### P00_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
1132
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
1133
- p00 = conv1 * self.backend.bk_conjugate(conv2)
1134
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1135
- p00 = self.backend.bk_reduce_sum(p00*vmask, axis=1)
1136
- tmp = self.backend.bk_L1(p00) # [Nbatch, Npix_j3, Norient3]
1137
-
1138
- ### Normalize P00_cross
1139
- if norm == 'auto':
1140
- p00 /= (P1_dic[j3] * P2_dic[j3])**0.5
1141
-
1142
- ### Store P00_cross as complex [Nbatch, Nmask, NP00, Norient3]
1143
- if not all_cross:
1144
- p00=self.backend.bk_real(p00)
1145
-
1146
- if P00 is None:
1147
- P00 = p00[:,:,None,:] # Add a dimension for NP00
1148
- else:
1149
- P00 = self.backend.bk_concat([P00, p00[:,:,None,:]], axis=2)
1150
-
1151
- #### S1_auto computation
1152
- ### Image 1 : S1 = < M1 >_pix
1153
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1154
- if tmp.dtype=='complex64' or tmp.dtype=='complex128':
1155
- s1 = self.backend.bk_reduce_sum(tmp[:, None, :]*self.backend.bk_complex(vmask[None,:, :],0*vmask[None,:, :]), axis=2)
1156
- else:
1157
- s1 = self.backend.bk_reduce_sum(tmp[:, None, :]*vmask[None,:, :], axis=2)
1158
-
1159
- ### Normalize S1
1160
- if norm is not None:
1161
- s1 /= (P1_dic[j3]) ** 0.5
1162
- ### We store S1 for image1 [Nbatch, Nmask, NS1]
1163
- if S1 is None:
1164
- S1 = s1[:, :, None] # Add a dimension for NS1
1165
- else:
1166
- S1 = self.backend.bk_concat([S1, s1[:, :, None]], axis=axis+2)
1167
-
1168
- # Initialize dictionaries for |I1*Psi_j| * Psi_j3
1169
- M1convPsi_dic = {}
1170
- if cross:
1171
- # Initialize dictionaries for |I2*Psi_j| * Psi_j3
1172
- M2convPsi_dic = {}
1173
-
1174
- ###### C01
1175
- for j2 in range(0, j3+1): # j2 <= j3
1176
- ### C01_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
1177
- if not cross:
1178
- c01 = self._compute_C01(j2,
1179
- conv1,
1180
- vmask,
1181
- M1_dic,
1182
- M1convPsi_dic) # [Nbatch, Nmask, Norient3, Norient2]
1183
- ### Normalize C01 with P00_j [Nbatch, Nmask, Norient_j]
1184
- if norm is not None:
1185
- c01 /= (P1_dic[j2][:, :, None, :] *
1186
- P1_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
1187
-
1188
- ### Store C01 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
1189
- if C01 is None:
1190
- C01 = c01[:,:,None] # Add a dimension for NC01
1191
- else:
1192
- C01 = self.backend.bk_concat([C01, c01[:, :, None]],
1193
- axis=2) # Add a dimension for NC01
1194
-
1195
- ### C01_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
1196
- ### C10_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
1197
- else:
1198
- c01 = self._compute_C01(j2,
1199
- conv1,
1200
- vmask,
1201
- M2_dic,
1202
- M2convPsi_dic)
1203
- c10 = self._compute_C01(j2,
1204
- conv2,
1205
- vmask,
1206
- M1_dic,
1207
- M1convPsi_dic)
1208
-
1209
- ### Normalize C01 and C10 with P00_j [Nbatch, Nmask, Norient_j]
1210
- if norm is not None:
1211
- c01 /= (P2_dic[j2][:, :, None, :] *
1212
- P1_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
1213
- c10 /= (P1_dic[j2][:, :, None, :] *
1214
- P2_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
1215
-
1216
- ### Store C01 and C10 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
1217
- if C01 is None:
1218
- C01 = c01[:, :, None] # Add a dimension for NC01
1219
- else:
1220
- C01 = self.backend.bk_concat([C01,c01[:, :, None]],axis=2) # Add a dimension for NC01
1221
- if C10 is None:
1222
- C10 = c10[:, :, None] # Add a dimension for NC01
1223
- else:
1224
- C10 = self.backend.bk_concat([C10,c10[:, :, None]], axis=2) # Add a dimension for NC01
1225
-
1226
-
1227
-
1228
- ##### C11
1229
- for j1 in range(0, j2+1): # j1 <= j2
1230
- ### C11_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
1231
- if not cross:
1232
- c11 = self._compute_C11(j1, j2, vmask,
1233
- M1convPsi_dic,
1234
- M2convPsi_dic=None) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1235
- ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
1236
- if norm is not None:
1237
- c11 /= (P1_dic[j1][:, :, None, None, :] *
1238
- P1_dic[j2][:, :, None, :,
1239
- None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1240
- ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
1241
- if C11 is None:
1242
- C11 = c11[:, :, None] # Add a dimension for NC11
1243
- else:
1244
- C11 = self.backend.bk_concat([C11,c11[:, :, None]],
1245
- axis=2) # Add a dimension for NC11
1246
-
1247
- ### C11_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
1248
- else:
1249
- c11 = self._compute_C11(j1, j2, vmask,
1250
- M1convPsi_dic,
1251
- M2convPsi_dic=M2convPsi_dic) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1252
- ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
1253
- if norm is not None:
1254
- c11 /= (P1_dic[j1][:, :, None, None, :] *
1255
- P2_dic[j2][:, :, None, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1256
- ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
1257
- if C11 is None:
1258
- C11 = c11[:, :, None] # Add a dimension for NC11
1259
- else:
1260
- C11 = self.backend.bk_concat([C11,c11[:, :, None]],
1261
- axis=2) # Add a dimension for NC11
1262
-
1263
- ###### Reshape for next iteration on j3
1264
- ### Image I1,
1265
- # downscale the I1 [Nbatch, Npix_j3]
1266
- if j3 != Jmax - 1:
1267
- I1_smooth = self.smooth_1d(I1, axis=axis+1)
1268
- I1 = self.ud_grade_1d(I1_smooth,I1.shape[axis+1]//2, axis=axis+1)
1269
-
1270
- ### Image I2
1271
- if cross:
1272
- I2_smooth = self.smooth_1d(I2, axis=axis+2)
1273
- I2 = self.ud_grade_1d(I2_smooth,I2.shape[axis+1]//2, axis=axis+1)
1274
-
1275
- ### Modules
1276
- for j2 in range(0, j3 + 1): # j2 =< j3
1277
- ### Dictionary M1_dic[j2]
1278
- M1_smooth = self.smooth_1d(M1_dic[j2], axis=axis+1) # [Nbatch, Npix_j3, Norient3]
1279
- M1_dic[j2] = self.ud_grade_1d(M1_smooth,M1_dic[j2].shape[axis+1]//2, axis=axis+1) # [Nbatch, Npix_j3, Norient3]
1280
-
1281
- ### Dictionary M2_dic[j2]
1282
- if cross:
1283
- M2_smooth = self.smooth_1d(M2_dic[j2], axis=axis+2) # [Nbatch, Npix_j3, Norient3]
1284
- M2_dic[j2] = self.ud_grade_1d(M2_smooth,M2_dic[j2].shape[axis+1]//2, axis=axis+2) # [Nbatch, Npix_j3, Norient3]
1285
- ### Mask
1286
- vmask = self.ud_grade_1d(vmask,vmask.shape[axis+1]//2, axis=1)
1287
-
1288
- if self.mask_thres is not None:
1289
- vmask = self.backend.bk_threshold(vmask,self.mask_thres)
1290
-
1291
- ### NSIDE_j3
1292
- nside_j3 = nside_j3 // 2
1293
-
1294
- ### Store P1_dic and P2_dic in self
1295
- if (norm == 'auto') and (self.P1_dic is None):
1296
- self.P1_dic = P1_dic
1297
- if cross:
1298
- self.P2_dic = P2_dic
1299
-
1300
- if not cross:
1301
- return scat_cov1D(P00, C01, C11, s1=S1,backend=self.backend)
1302
- else:
1303
- return scat_cov1D(P00, C01, C11, c10=C10,backend=self.backend)
1304
-
1305
- def clean_norm(self):
1306
- self.P1_dic = None
1307
- self.P2_dic = None
1308
- return
1309
-
1310
- def _compute_C01(self, j2, conv,
1311
- vmask, M_dic,
1312
- MconvPsi_dic):
1313
- """
1314
- Compute the C01 coefficients (auto or cross)
1315
- C01 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
1316
- Parameters
1317
- ----------
1318
- Returns
1319
- -------
1320
- cc01, sc01: real and imag parts of C01 coeff
1321
- """
1322
- ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
1323
- # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
1324
- MconvPsi = self.convol_1d(M_dic[j2], axis=1) # [Nbatch, Npix_j3, Norient3, Norient2]
1325
-
1326
- # Store it so we can use it in C11 computation
1327
- MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
1328
-
1329
- ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
1330
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
1331
- # cconv, sconv are [Nbatch, Npix_j3, Norient3]
1332
- c01 = conv * self.backend.bk_conjugate(MconvPsi) # [Nbatch, Npix_j3, Norient3, Norient2]
1333
-
1334
- ### Apply the mask [Nmask, Npix_j3] and sum over pixels
1335
- c01 = self.masked_mean(c01, vmask, axis=1,rank=j2) # [Nbatch, Nmask, Norient3, Norient2]
1336
-
1337
- return c01
1338
-
1339
- def _compute_C11(self, j1, j2, vmask,
1340
- M1convPsi_dic,
1341
- M2convPsi_dic=None):
1342
- #### Simplify notations
1343
- M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
1344
-
1345
- # Auto or Cross coefficients
1346
- if M2convPsi_dic is None: # Auto
1347
- M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
1348
- else: # Cross
1349
- M2 = M2convPsi_dic[j2]
1350
-
1351
- ### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
1352
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
1353
- c11 = M1 * M2 # [Nbatch, Npix_j3]
1354
-
1355
- ### Apply the mask and sum over pixels
1356
- c11 = self.masked_mean(c11, vmask, axis=1,rank=j2) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1357
- return c11
1358
-
1359
- def square(self, x):
1360
- if isinstance(x, scat_cov1D):
1361
- if x.S1 is None:
1362
- return scat_cov1D(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1363
- self.backend.bk_square(self.backend.bk_abs(x.C01)),
1364
- self.backend.bk_square(self.backend.bk_abs(x.C11)),backend=self.backend)
1365
- else:
1366
- return scat_cov1D(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1367
- self.backend.bk_square(self.backend.bk_abs(x.C01)),
1368
- self.backend.bk_square(self.backend.bk_abs(x.C11)),
1369
- s1=self.backend.bk_square(self.backend.bk_abs(x.S1)),backend=self.backend)
1370
- else:
1371
- return self.backend.bk_abs(self.backend.bk_square(x))
1372
-
1373
- def sqrt(self, x):
1374
- if isinstance(x, scat_cov1D):
1375
- if x.S1 is None:
1376
- return scat_cov1D(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1377
- self.backend.bk_sqrt(self.backend.bk_abs(x.C01)),
1378
- self.backend.bk_sqrt(self.backend.bk_abs(x.C11)),backend=self.backend)
1379
- else:
1380
- return scat_cov1D(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1381
- self.backend.bk_sqrt(self.backend.bk_abs(x.C01)),
1382
- self.backend.bk_sqrt(self.backend.bk_abs(x.C11)),
1383
- s1=self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),backend=self.backend)
1384
- else:
1385
- return self.backend.bk_abs(self.backend.bk_sqrt(x))
1386
-
1387
- def reduce_mean(self, x):
1388
- if isinstance(x, scat_cov1D):
1389
- if x.S1 is None:
1390
- result = (self.backend.bk_reduce_mean(self.backend.bk_abs(x.P00)) + \
1391
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.C01)) + \
1392
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.C11)))/3
1393
- else:
1394
- result = (self.backend.bk_reduce_mean(self.backend.bk_abs(x.P00)) + \
1395
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.S1)) + \
1396
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.C01)) + \
1397
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.C11)))/4
1398
- else:
1399
- return self.backend.bk_reduce_mean(x)
1400
- return result
1401
-
1402
- def reduce_sum(self, x):
1403
-
1404
- if isinstance(x, scat_cov1D):
1405
- if x.S1 is None:
1406
- result = self.backend.bk_reduce_sum(x.P00) + \
1407
- self.backend.bk_reduce_sum(x.C01) + \
1408
- self.backend.bk_reduce_sum(x.C11)
1409
- else:
1410
- result = self.backend.bk_reduce_sum(x.P00) + \
1411
- self.backend.bk_reduce_sum(x.S1) + \
1412
- self.backend.bk_reduce_sum(x.C01) + \
1413
- self.backend.bk_reduce_sum(x.C11)
1414
- else:
1415
- return self.backend.bk_reduce_sum(x)
1416
- return result
1417
-
1418
-
1419
- def ldiff(self,sig,x):
1420
-
1421
- if x.S1 is None:
1422
- if x.C11 is not None:
1423
- return scat_cov1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1424
- x.domult(sig.C01,x.C01)*x.domult(sig.C01,x.C01),
1425
- x.domult(sig.C11,x.C11)*x.domult(sig.C11,x.C11),
1426
- backend=self.backend)
1427
- else:
1428
- return scat_cov1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1429
- x.domult(sig.C01,x.C01)*x.domult(sig.C01,x.C01),
1430
- 0*sig.C01,
1431
- backend=self.backend)
1432
- else:
1433
- if x.C11 is None:
1434
- return scat_cov1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1435
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1436
- x.domult(sig.C01,x.C01)*x.domult(sig.C01,x.C01),
1437
- 0*sig.S1,
1438
- backend=self.backend)
1439
- else:
1440
- return scat_cov1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1441
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1442
- x.domult(sig.C01,x.C01)*x.domult(sig.C01,x.C01),
1443
- x.domult(sig.C11,x.C11)*x.domult(sig.C11,x.C11),
1444
- backend=self.backend)
1445
-
1446
-
1447
- def log(self, x):
1448
- if isinstance(x, scat_cov1D):
5
+ def __init__(self, s0, s2, s3, s4, s1=None, s3p=None, backend=None):
1449
6
 
1450
- if x.S1 is None:
1451
- result = self.backend.bk_log(x.P00) + \
1452
- self.backend.bk_log(x.C01) + \
1453
- self.backend.bk_log(x.C11)
1454
- else:
1455
- result = self.backend.bk_log(x.P00) + \
1456
- self.backend.bk_log(x.S1) + \
1457
- self.backend.bk_log(x.C01) + \
1458
- self.backend.bk_log(x.C11)
1459
- else:
1460
- return self.backend.bk_log(x)
1461
-
1462
- return result
7
+ the_scat = scat(s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend)
8
+ the_scat.set_bk_type("SCAT_COV1D")
9
+ return the_scat
1463
10
 
1464
- @tf.function
1465
- def eval_comp_fast(self, image1, image2=None,mask=None,norm=None, Auto=True,Add_R45=False):
11
+ def fill(self, im, nullval=0):
12
+ return self.fill_1d(im, nullval=nullval)
1466
13
 
1467
- res=self.eval(image1, image2=image2,mask=mask,Auto=Auto,Add_R45=Add_R45)
1468
- return res.P00,res.S1,res.C01,res.C11,res.C10
1469
14
 
1470
- def eval_fast(self, image1, image2=None,mask=None,norm=None, Auto=True,Add_R45=False):
1471
- p0,s1,c01,c11,c10=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,Add_R45=Add_R45)
1472
- return scat_cov1D(p0, c01, c11, s1=s1,c10=c10,backend=self.backend)
1473
-
1474
-
15
+ class funct(scat.funct):
16
+ def __init__(self, *args, **kwargs):
17
+ # Impose que use_2D=True pour la classe scat
18
+ super().__init__(use_1D=True, *args, **kwargs)