foscat 3.2.0__py3-none-any.whl → 3.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov1D.old.py DELETED
@@ -1,1547 +0,0 @@
1
- import foscat.FoCUS as FOC
2
- import numpy as np
3
- import foscat.backend as bk
4
- import pickle
5
- import matplotlib.pyplot as plt
6
- import sys
7
-
8
-
9
# TensorFlow is only used by this module if something else already imported it.
tf_defined = 'tensorflow' in sys.modules

if not tf_defined:
    def tf_function(func):
        """Identity decorator standing in for ``tf.function`` without TensorFlow."""
        return func
else:
    import tensorflow as tf
    # Optional: expose tf.function so this module can use TensorFlow when it is loaded.
    tf_function = tf.function
19
def read(filename):
    """Load a statistic written by ``scat_cov1D.save``.

    ``filename`` is given without its ``.pkl`` suffix.  The returned
    scat_cov1D has no backend attached (``backend=None``).
    """
    placeholder = scat_cov1D(1, 1, 1)
    return placeholder.read(filename)


# Module-level flag; its purpose is not evident from the visible code.
testwarn = 0
24
-
25
class scat_cov1D:
    """Container for 1D scattering-covariance coefficient tables.

    ``P00``, ``C01`` and ``C11`` are always passed; ``S1`` and ``C10`` are
    optional (``None`` when absent).  Tables may be NumPy arrays or backend
    tensors.  ``backend`` is the foscat backend object used for
    tensor-specific operations (``constant``, ``bk_relu``, ``bk_complex``,
    ``bk_reshape`` ...).
    """

    def __init__(self, p00, c01, c11, s1=None, c10=None, backend=None):
        self.P00 = p00
        self.C01 = c01
        self.C11 = c11
        self.S1 = s1
        self.C10 = c10
        self.backend = backend
        # Orientation index tables, filled by initdx().
        self.idx1 = None
        self.idx2 = None

    # ------------------------------------------------------------------
    # Conversions and accessors
    # ------------------------------------------------------------------
    def numpy(self):
        """Return a copy with every table converted to NumPy.

        Tensors go through ``.numpy()``; NumPy arrays are kept as-is and
        ``None`` tables stay ``None``.
        """
        return scat_cov1D(self.get_np(self.P00),
                          self.get_np(self.C01),
                          self.get_np(self.C11),
                          s1=self.get_np(self.S1),
                          c10=self.get_np(self.C10),
                          backend=self.backend)

    def constant(self):
        """Return a copy with every non-None table passed through ``backend.constant``."""
        def const(table):
            return None if table is None else self.backend.constant(table)

        return scat_cov1D(const(self.P00), const(self.C01), const(self.C11),
                          s1=const(self.S1), c10=const(self.C10),
                          backend=self.backend)

    def get_S1(self):
        return self.S1

    def get_P00(self):
        return self.P00

    def reset_P00(self):
        """Replace P00 by zeros of the same shape (``0 * P00``)."""
        self.P00 = 0 * self.P00

    def get_C01(self):
        return self.C01

    def get_C10(self):
        return self.C10

    def get_C11(self):
        return self.C11

    def get_j_idx(self):
        """Return the (j1, j2) scale pairs with j1 <= j2, ordered by j2 then j1.

        The number of scales is ``P00.shape[2]`` for a 3D P00 and
        ``P00.shape[3]`` otherwise.
        NOTE(review): other methods (get_nscale, interp) read nscale from
        shape[2] for 4D P00 as well -- confirm which axis is intended here.
        """
        shape = list(self.P00.shape)
        nscale = shape[2] if len(shape) == 3 else shape[3]
        j1 = np.array([j for i in range(nscale) for j in range(i + 1)], dtype='int')
        j2 = np.array([i for i in range(nscale) for _ in range(i + 1)], dtype='int')
        return j1, j2

    def get_jc11_idx(self):
        """Return the (j1, j2, j3) scale triples with j1 <= j2 <= j3.

        Ordered by j3, then j2, then j1; the number of scales is
        ``P00.shape[2]``.  The arrays are sized exactly
        (nscale*(nscale+1)*(nscale+2)/6 entries), which also works for
        nscale < 3.
        """
        nscale = list(self.P00.shape)[2]
        order = [(k, j, i)
                 for i in range(nscale)
                 for j in range(i + 1)
                 for k in range(j + 1)]
        j1 = np.array([t[0] for t in order], dtype='int')
        j2 = np.array([t[1] for t in order], dtype='int')
        j3 = np.array([t[2] for t in order], dtype='int')
        return (j1, j2, j3)

    # ------------------------------------------------------------------
    # Arithmetic
    # ------------------------------------------------------------------
    def _binary_op(self, other, op, reverse=False):
        """Apply ``op`` table-by-table and return a new scat_cov1D.

        ``op`` is one of ``doadd``/``domin``/``domult``/``dodiv``.  With
        ``reverse=True`` the operands are swapped (``other op self``), which
        implements the reflected operators.  ``other`` is a Python/NumPy
        scalar (applied to every table) or another scat_cov1D (applied
        table-wise).  Optional tables (S1, C10, C11) are None in the result
        when they are None in ``self`` or, for a scat_cov1D operand, in
        ``other``.
        """
        assert isinstance(other, (float, np.float32, int, bool, scat_cov1D)), \
            'unsupported operand type %s' % type(other)
        is_stat = isinstance(other, scat_cov1D)

        def combine(mine, theirs):
            return op(theirs, mine) if reverse else op(mine, theirs)

        def optional(name):
            mine = getattr(self, name)
            if mine is None:
                return None
            if not is_stat:
                return combine(mine, other)
            theirs = getattr(other, name)
            return None if theirs is None else combine(mine, theirs)

        if is_stat:
            p00 = combine(self.P00, other.P00)
            c01 = combine(self.C01, other.C01)
        else:
            p00 = combine(self.P00, other)
            c01 = combine(self.C01, other)

        return scat_cov1D(p00, c01, optional('C11'),
                          s1=optional('S1'), c10=optional('C10'),
                          backend=self.backend)

    def __add__(self, other):
        return self._binary_op(other, self.doadd)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        return self._binary_op(other, self.domin)

    def __rsub__(self, other):
        return self._binary_op(other, self.domin, reverse=True)

    def __mul__(self, other):
        return self._binary_op(other, self.domult)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        return self._binary_op(other, self.dodiv)

    def __rtruediv__(self, other):
        return self._binary_op(other, self.dodiv, reverse=True)

    def relu(self):
        """Return a copy with ``backend.bk_relu`` applied to every non-None table."""
        def apply(table):
            return None if table is None else self.backend.bk_relu(table)

        return scat_cov1D(apply(self.P00),
                          apply(self.C01),
                          apply(self.C11),
                          s1=apply(self.S1),
                          c10=apply(self.C10),
                          backend=self.backend)

    # The do* helpers first try the plain operator.  If it raises and the
    # operands have different dtypes (presumably a real/complex mismatch the
    # backend refuses), the complex result is rebuilt from real and
    # imaginary parts via the backend.
    def domult(self, x, y):
        """Return ``x * y`` with a real/complex fallback."""
        try:
            return x * y
        except Exception:
            if x.dtype == y.dtype:
                raise
            bk = self.backend
            if bk.bk_is_complex(x):
                # (a + ib) * y = a*y + i b*y
                return bk.bk_complex(bk.bk_real(x) * y, bk.bk_imag(x) * y)
            return bk.bk_complex(bk.bk_real(y) * x, bk.bk_imag(y) * x)

    def dodiv(self, x, y):
        """Return ``x / y`` with a real/complex fallback."""
        try:
            return x / y
        except Exception:
            if x.dtype == y.dtype:
                raise
            bk = self.backend
            if bk.bk_is_complex(x):
                # (a + ib) / y = a/y + i b/y
                return bk.bk_complex(bk.bk_real(x) / y, bk.bk_imag(x) / y)
            # x / (a + ib) = x (a - ib) / (a^2 + b^2)
            re_y = bk.bk_real(y)
            im_y = bk.bk_imag(y)
            denom = re_y * re_y + im_y * im_y
            return bk.bk_complex(x * re_y / denom, -x * im_y / denom)

    def domin(self, x, y):
        """Return ``x - y`` with a real/complex fallback."""
        try:
            return x - y
        except Exception:
            if x.dtype == y.dtype:
                raise
            bk = self.backend
            if bk.bk_is_complex(x):
                # (a + ib) - y = (a - y) + ib
                return bk.bk_complex(bk.bk_real(x) - y, bk.bk_imag(x))
            # x - (a + ib) = (x - a) - ib
            return bk.bk_complex(x - bk.bk_real(y), -bk.bk_imag(y))

    def doadd(self, x, y):
        """Return ``x + y`` with a real/complex fallback."""
        try:
            return x + y
        except Exception:
            if x.dtype == y.dtype:
                raise
            bk = self.backend
            if bk.bk_is_complex(x):
                # (a + ib) + y = (a + y) + ib
                return bk.bk_complex(bk.bk_real(x) + y, bk.bk_imag(x))
            # x + (a + ib) = (x + a) + ib
            return bk.bk_complex(x + bk.bk_real(y), bk.bk_imag(y))

    # ------------------------------------------------------------------
    # Scale interpolation
    # ------------------------------------------------------------------
    def interp(self, nscale, extend=True, constant=False):
        """Rebuild the first ``nscale`` scales from the following ones.

        For k = 0..nscale-1, scale ``nscale-1-k`` is overwritten either by a
        copy of scale ``nscale-k`` (``constant=True``) or by the log-linear
        extrapolation ``exp(2*log(a) - log(b))`` from scales ``nscale-k``
        and ``nscale+1-k``.  S1/P00, C01/C10 and C11 are updated; S1 and C10
        are skipped when absent.  ``extend`` is accepted but not used.

        Returns a new scat_cov1D whose tables went through
        ``backend.constant``; NumPy copies are modified, not the inputs.
        If fewer than ``nscale+2`` scales exist, a message is printed and
        an unchanged copy is returned.
        """
        if nscale + 2 > self.P00.shape[2]:
            print('Can not *interp* %d with a statistic described over %d' % (nscale, self.P00.shape[2]))
            return scat_cov1D(self.P00, self.C01, self.C11, s1=self.S1, c10=self.C10, backend=self.backend)

        def extrapolate(near, far):
            return np.exp(2 * np.log(near) - np.log(far))

        # Work on NumPy copies so the stored tables are left untouched.
        s1 = None if self.S1 is None else np.array(self.get_np(self.S1))
        p0 = np.array(self.get_np(self.P00))
        for k in range(nscale):
            dst, src, ref = nscale - 1 - k, nscale - k, nscale + 1 - k
            if constant:
                if s1 is not None:
                    s1[:, :, dst, :] = s1[:, :, src, :]
                p0[:, :, dst, :] = p0[:, :, src, :]
            else:
                if s1 is not None:
                    s1[:, :, dst, :] = extrapolate(s1[:, :, src, :], s1[:, :, ref, :])
                p0[:, :, dst, :] = extrapolate(p0[:, :, src, :], p0[:, :, ref, :])

        j1, j2 = self.get_j_idx()
        c10 = None if self.C10 is None else np.array(self.get_np(self.C10))
        c01 = np.array(self.get_np(self.C01))
        for k in range(nscale):
            for l in range(nscale - k):
                jlow = nscale - 1 - k - l
                i0 = np.where((j1 == jlow) * (j2 == nscale - 1 - k))[0]
                i1 = np.where((j1 == jlow) * (j2 == nscale - k))[0]
                i2 = np.where((j1 == jlow) * (j2 == nscale + 1 - k))[0]
                if constant:
                    if c10 is not None:
                        c10[:, :, i0] = c10[:, :, i1]
                    c01[:, :, i0] = c01[:, :, i1]
                else:
                    if c10 is not None:
                        c10[:, :, i0] = extrapolate(c10[:, :, i1], c10[:, :, i2])
                    c01[:, :, i0] = extrapolate(c01[:, :, i1], c01[:, :, i2])

        c11 = np.array(self.get_np(self.C11))
        j1, j2, j3 = self.get_jc11_idx()
        for k in range(nscale):
            for l in range(nscale - k):
                for m in range(nscale - k - l):
                    base = (j1 == nscale - 1 - k - l - m) * (j2 == nscale - 1 - k - l)
                    i0 = np.where(base * (j3 == nscale - 1 - k))[0]
                    i1 = np.where(base * (j3 == nscale - k))[0]
                    i2 = np.where(base * (j3 == nscale + 1 - k))[0]
                    if constant:
                        c11[:, :, i0] = c11[:, :, i1]
                    else:
                        c11[:, :, i0] = extrapolate(c11[:, :, i1], c11[:, :, i2])

        const = self.backend.constant
        return scat_cov1D(const(p0), const(c01), const(c11),
                          s1=None if s1 is None else const(s1),
                          c10=None if c10 is None else const(c10),
                          backend=self.backend)

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------
    def plot(self, name=None, hold=True, color='blue', lw=1, legend=True):
        """Plot S1 (if present), P00, C01 (or C10) and C11 on a 2x2 grid.

        ``hold=True`` opens a new 16x8 figure first.  ``name`` prefixes the
        legend labels; ``legend`` controls whether the C01/C10 and C11
        panels get a labelled curve.  Values are plotted as absolute values
        on a log y-scale.
        """
        import matplotlib.pyplot as plt

        if name is None:
            name = ''

        def finish_twin_axis(ax2, xmax, ticks, labels):
            # Move the twinned axis ticks/label from top to bottom, offset it
            # below the host axis, and show only its bottom spine.
            ax2.xaxis.set_ticks_position("bottom")
            ax2.xaxis.set_label_position("bottom")
            ax2.spines["bottom"].set_position(("axes", -0.15))
            ax2.set_frame_on(True)
            ax2.patch.set_visible(False)
            for sp in ax2.spines.values():
                sp.set_visible(False)
            ax2.spines["bottom"].set_visible(True)
            ax2.set_xlim(0, xmax)
            ax2.set_xticks(ticks)
            ax2.set_xticklabels(labels, fontsize=6)
            ax2.set_xlabel(r"$j_{1}$", fontsize=6)

        j1, j2 = self.get_j_idx()

        if hold:
            plt.figure(figsize=(16, 8))

        # Panel 1: S1, one curve per tmp[i1, i2, :].
        if self.S1 is not None:
            plt.subplot(2, 2, 1)
            tmp = abs(self.get_np(self.S1))
            first = True
            for i1 in range(tmp.shape[0]):
                for i2 in range(tmp.shape[1]):
                    if first:
                        first = False
                        plt.plot(tmp[i1, i2, :], color=color, label=r'%s $S_1$' % (name), lw=lw)
                    else:
                        plt.plot(tmp[i1, i2, :], color=color, lw=lw)
            plt.yscale('log')
            plt.legend()
            plt.ylabel('S1')
            plt.xlabel(r'$j_{1}$')

        # Panel 2: P00, one curve per tmp[i1, i2, :].
        plt.subplot(2, 2, 2)
        tmp = abs(self.get_np(self.P00))
        first = True
        for i1 in range(tmp.shape[0]):
            for i2 in range(tmp.shape[1]):
                if first:
                    first = False
                    plt.plot(tmp[i1, i2, :], color=color, label=r'%s $P_{00}$' % (name), lw=lw)
                else:
                    plt.plot(tmp[i1, i2, :], color=color, lw=lw)
        plt.yscale('log')
        plt.ylabel('P00')
        plt.xlabel(r'$j_{1}$')
        plt.legend()

        # Panel 3: C10 when present (cross statistics), otherwise C01.
        ax1 = plt.subplot(2, 2, 3)
        ax2 = ax1.twiny()
        n = 0
        if self.C10 is not None:
            tmp = abs(self.get_np(self.C10))
            lname = r'%s $C_{10}$' % (name)
            ax1.set_ylabel(r'$C_{10}$')
        else:
            tmp = abs(self.get_np(self.C01))
            lname = r'%s $C_{01}$' % (name)
            ax1.set_ylabel(r'$C_{01}$')
        labelled = False
        tabx, tabnx, tab2x, tab2nx = [], [], [], []
        for i0 in range(tmp.shape[0]):
            for i1 in range(tmp.shape[1]):
                for i2 in range(j1.max() + 1):
                    sel = j1 == i2
                    xpos = j2[sel] + n
                    if xpos.shape[0] == 1:
                        ax1.plot(xpos, tmp[i0, i1, sel], '.', color=color, lw=lw)
                    elif legend and not labelled:
                        ax1.plot(xpos, tmp[i0, i1, sel], color=color, label=lname, lw=lw)
                        labelled = True
                    else:
                        ax1.plot(xpos, tmp[i0, i1, sel], color=color, lw=lw)
                    tabnx.extend(r'%d' % (k) for k in j2[sel])
                    tabx.extend(k + n for k in j2[sel])
                    tab2x.append(xpos.mean())
                    tab2nx.append('%d' % (i2))
                    ax1.axvline(xpos.max() + 0.5, ls=':', color='gray')
                    n = n + j2[sel].shape[0] - 1
        plt.yscale('log')
        ax1.set_xlim(0, n + 2)
        ax1.set_xticks(tabx)
        ax1.set_xticklabels(tabnx, fontsize=6)
        ax1.set_xlabel(r"$j_{2}$", fontsize=6)
        finish_twin_axis(ax2, n + 2, tab2x, tab2nx)
        ax1.legend(frameon=0)

        # Panel 4: C11 grouped by j1, ticks labelled "j2,j3".
        ax1 = plt.subplot(2, 2, 4)
        j1, j2, j3 = self.get_jc11_idx()
        ax2 = ax1.twiny()
        n = 1
        tmp = abs(self.get_np(self.C11))
        lname = r'%s $C_{11}$' % (name)
        labelled = False
        tabx, tabnx, tab2x, tab2nx = [], [], [], []
        for i0 in range(tmp.shape[0]):
            for i1 in range(tmp.shape[1]):
                for i2 in range(j1.max() + 1):
                    nprev = n
                    for i2b in range(j2[j1 == i2].max() + 1):
                        idx = np.where((j1 == i2) * (j2 == i2b))[0]
                        xpos = np.arange(len(idx)) + n
                        if len(idx) == 1:
                            ax1.plot(xpos, tmp[i0, i1, idx], '.', color=color, lw=lw)
                        elif legend and not labelled:
                            ax1.plot(xpos, tmp[i0, i1, idx], color=color, label=lname, lw=lw)
                            labelled = True
                        else:
                            ax1.plot(xpos, tmp[i0, i1, idx], color=color, lw=lw)
                        tabnx.extend(r'%d,%d' % (j2[k], j3[k]) for k in idx)
                        tabx.extend(k + n for k in range(len(idx)))
                        n = n + idx.shape[0]
                    tab2x.append((n + nprev - 1) / 2)
                    tab2nx.append('%d' % (i2))
                    ax1.axvline(n - 0.5, ls=':', color='gray')
        plt.yscale('log')
        ax1.set_ylabel(r'$C_{11}$')
        ax1.set_xticks(tabx)
        ax1.set_xticklabels(tabnx, fontsize=6)
        ax1.set_xlabel(r"$j_{2},j_{3}$", fontsize=6)
        ax1.set_xlim(0, n)
        finish_twin_axis(ax2, n, tab2x, tab2nx)
        ax1.legend(frameon=0)

    # ------------------------------------------------------------------
    # I/O and summaries
    # ------------------------------------------------------------------
    def get_np(self, x):
        """Return ``x`` as a NumPy array (``.numpy()`` for tensors); None stays None."""
        if x is None:
            return None
        if isinstance(x, np.ndarray):
            return x
        return x.numpy()

    def save(self, filename):
        """Pickle [S1, C10, C01, C11, P00] (as NumPy) to ``<filename>.pkl``."""
        outlist = [self.get_np(self.S1),
                   self.get_np(self.C10),
                   self.get_np(self.C01),
                   self.get_np(self.C11),
                   self.get_np(self.P00)]
        with open("%s.pkl" % (filename), "wb") as myout:
            pickle.dump(outlist, myout)

    def read(self, filename):
        """Load a statistic written by ``save`` from ``<filename>.pkl``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only read
        files from trusted sources.
        """
        with open("%s.pkl" % (filename), "rb") as infile:
            outlist = pickle.load(infile)
        return scat_cov1D(outlist[4], outlist[2], outlist[3],
                          s1=outlist[0], c10=outlist[1], backend=self.backend)

    def _summary_tables(self):
        # Auto statistics (S1 present) use S1; cross statistics use C10.
        if self.S1 is not None:
            return [self.S1, self.C01, self.C11, self.P00]
        return [self.C01, self.C10, self.C11, self.P00]

    def std(self):
        """Root-mean-square of the standard deviations of |table| over four tables."""
        return np.sqrt(sum(abs(self.get_np(t)).std() ** 2 for t in self._summary_tables()) / 4)

    def mean(self):
        """Average of the means of |table| over four tables."""
        return sum(abs(self.get_np(t)).mean() for t in self._summary_tables()) / 4

    def initdx(self, norient):
        """Build the rotated orientation index tables ``idx1`` and ``idx2``.

        Both are stored as ``backend.constant`` integer tables of length
        norient**2 and norient**3.
        """
        ramp = np.arange(norient)
        idx1 = np.zeros([norient * norient], dtype='int')
        for i in range(norient):
            idx1[i * norient:(i + 1) * norient] = (ramp + i) % norient + i * norient

        idx2 = np.zeros([norient * norient * norient], dtype='int')
        for i in range(norient):
            for j in range(norient):
                start = i * norient * norient + j * norient
                idx2[start:start + norient] = ((ramp + i) % norient) * norient \
                    + (ramp + i + j) % norient + ramp * norient * norient
        self.idx1 = self.backend.constant(idx1)
        self.idx2 = self.backend.constant(idx2)

    def get_nscale(self):
        return self.P00.shape[2]

    def get_norient(self):
        return self.P00.shape[3]

    # ------------------------------------------------------------------
    # Flattening
    # ------------------------------------------------------------------
    def build_flat(self, table):
        """Reshape ``table`` to [table.shape[0], product of the other dims]."""
        ndata = 1
        for dim in list(table.shape)[1:]:
            ndata = ndata * dim
        return self.backend.bk_reshape(table, [table.shape[0], ndata])

    def flatten(self):
        """Concatenate the real parts of P00, S1, C01, C10, C11 along axis 1.

        Absent optional tables are skipped.  NumPy concatenation is used
        when C11 is a NumPy array, otherwise ``backend.bk_concat``.
        """
        tables = [self.P00]
        if self.S1 is not None:
            tables.append(self.S1)
        tables.append(self.C01)
        if self.C10 is not None:
            tables.append(self.C10)
        tables.append(self.C11)
        tmp = [self.backend.bk_real(self.build_flat(t)) for t in tables]

        if isinstance(self.C11, np.ndarray):
            return np.concatenate(tmp, 1)
        return self.backend.bk_concat(tmp, 1)

    def flatten_name(self):
        """Return coefficient names in the same table order as ``flatten``."""
        tmp = ['P00_%d' % (k) for k in range(self.P00.shape[-1])]

        if self.S1 is not None:
            tmp += ['S1_%d' % (k) for k in range(self.S1.shape[-1])]

        j1, j2 = self.get_j_idx()
        tmp += ['C01_%d-%d' % (j1[k], j2[k]) for k in range(self.C01.shape[-1])]

        if self.C10 is not None:
            tmp += ['C10_%d-%d' % (j1[k], j2[k]) for k in range(self.C10.shape[-1])]

        j1, j2, j3 = self.get_jc11_idx()
        tmp += ['C11_%d-%d-%d' % (j1[k], j2[k], j3[k]) for k in range(self.C11.shape[-1])]

        return tmp

    # ------------------------------------------------------------------
    # Scale extension
    # ------------------------------------------------------------------
    def add_data_from_log_slope(self, y, n, ds=3):
        """Extrapolate ``n`` values preceding ``y`` from a fit of ``log(y)``.

        A straight line is fitted to ``log(y)`` over the first
        ``min(len(y), ds)`` samples (x = 0, 1, ...); a single sample is
        simply repeated ``n`` times.
        NOTE(review): the result is evaluated at x = -n-1 .. -2, not
        -n .. -1 -- confirm this offset is intended.
        """
        if len(y) < ds:
            if len(y) == 1:
                return np.repeat(y[0], n)
            npts = len(y)
        else:
            npts = ds
        a = np.polyfit(np.arange(npts), np.log(y[0:npts]), 1)
        return np.exp((np.arange(n) - 1 - n) * a[0] + a[1])

    def add_data_from_slope(self, y, n, ds=3):
        """Linear counterpart of ``add_data_from_log_slope`` (fit on ``y`` itself)."""
        if len(y) < ds:
            if len(y) == 1:
                return np.repeat(y[0], n)
            npts = len(y)
        else:
            npts = ds
        a = np.polyfit(np.arange(npts), y[0:npts], 1)
        return (np.arange(n) - 1 - n) * a[0] + a[1]

    def up_grade(self, nscale, ds=3):
        """Extend the statistic to ``nscale`` scales by extrapolation.

        ``noff = nscale - P00.shape[2]`` new scales are prepended.  P00, S1
        and C11 use ``add_data_from_log_slope``; C01 uses
        ``add_data_from_slope``; remaining empty C11 entries are filled from
        neighbouring triples with ``add_data_from_slope``.  C10 is not
        carried over.  When ``noff == 0`` a new object sharing the same
        tables is returned.
        """
        noff = nscale - self.P00.shape[2]
        if noff == 0:
            return scat_cov1D(self.P00, self.C01, self.C11,
                              s1=self.S1, c10=self.C10, backend=self.backend)

        inscale = self.P00.shape[2]

        # --- P00 (complex output) ------------------------------------
        p00_in = self.get_np(self.P00)
        p00 = np.zeros([p00_in.shape[0], p00_in.shape[1], nscale, p00_in.shape[3]], dtype='complex')
        p00[:, :, noff:, :] = p00_in
        for i in range(p00_in.shape[0]):
            for j in range(p00_in.shape[1]):
                for k in range(p00_in.shape[3]):
                    p00[i, j, 0:noff, k] = self.add_data_from_log_slope(p00[i, j, noff:, k], noff, ds=ds)

        # --- S1 (optional) -------------------------------------------
        if self.S1 is None:
            s1 = None
        else:
            s1_in = self.get_np(self.S1)
            s1 = np.zeros([s1_in.shape[0], s1_in.shape[1], nscale, s1_in.shape[3]])
            s1[:, :, noff:, :] = s1_in
            for i in range(s1_in.shape[0]):
                for j in range(s1_in.shape[1]):
                    for k in range(s1_in.shape[3]):
                        s1[i, j, 0:noff, k] = self.add_data_from_log_slope(s1[i, j, noff:, k], noff, ds=ds)

        # --- C01: pairs (j1 < j2) -------------------------------------
        c01_in = self.get_np(self.C01)
        nout = nscale * (nscale - 1) // 2
        c01 = np.zeros([c01_in.shape[0], c01_in.shape[1],
                        nout, c01_in.shape[3], c01_in.shape[4]], dtype='complex')

        jo1 = np.zeros([nout])
        jo2 = np.zeros([nout])
        n = 0
        for i in range(1, nscale):
            jo1[n:n + i] = np.arange(i)
            jo2[n:n + i] = i
            n = n + i

        j2 = np.zeros([c01_in.shape[2]])
        n = 0
        for i in range(1, inscale):
            j2[n:n + i] = i
            n = n + i

        for i in range(c01_in.shape[0]):
            for j in range(c01_in.shape[1]):
                for k in range(c01_in.shape[3]):
                    for l in range(c01_in.shape[4]):
                        for ij in range(noff + 1, nscale):
                            idx = np.where(jo2 == ij)[0]
                            known = c01_in[i, j, j2 == ij - noff, k, l]
                            c01[i, j, idx[noff:], k, l] = known
                            c01[i, j, idx[:noff], k, l] = self.add_data_from_slope(known, noff, ds=ds)

                        for ij in range(nscale):
                            idx = np.where(jo1 == ij)[0]
                            if idx.shape[0] > noff:
                                c01[i, j, idx[:noff], k, l] = self.add_data_from_slope(c01[i, j, idx[noff:], k, l], noff, ds=ds)
                            else:
                                c01[i, j, idx, k, l] = np.mean(c01[i, j, jo1 == ij - 1, k, l])

        # --- C11: triples (j1 < j2 < j3) ------------------------------
        def strict_triples(ns):
            # (j1, j2, j3) with j1 < j2 < j3 < ns, ordered by j3, j2, j1.
            t = np.array([(a, b, c) for c in range(ns) for b in range(c) for a in range(b)],
                         dtype=float).reshape(-1, 3)
            return t[:, 0], t[:, 1], t[:, 2]

        c11_in = self.get_np(self.C11)
        jo1, jo2, jo3 = strict_triples(nscale)
        nout = jo1.shape[0]
        c11 = np.zeros([c11_in.shape[0], c11_in.shape[1],
                        nout, c11_in.shape[3],
                        c11_in.shape[4], c11_in.shape[5]], dtype='complex')
        jj1, jj2, jj3 = strict_triples(inscale)

        for j3 in range(nscale):
            for j2 in range(j3):
                idx = np.where((jj3 == j3) * (jj2 == j2))[0]
                if idx.shape[0] > 0:
                    idx2 = np.where((jo3 == j3 + noff) * (jo2 == j2 + noff))[0]
                    for i in range(c11_in.shape[0]):
                        for j in range(c11_in.shape[1]):
                            for k in range(c11_in.shape[3]):
                                for l in range(c11_in.shape[4]):
                                    for m in range(c11_in.shape[5]):
                                        known = c11_in[i, j, idx, k, l, m]
                                        c11[i, j, idx2[noff:], k, l, m] = known
                                        c11[i, j, idx2[:noff], k, l, m] = self.add_data_from_log_slope(known, noff, ds=ds)

        # Fill still-empty triples from the two diagonal neighbours ...
        idx = np.where(abs(c11[0, 0, :, 0, 0, 0]) == 0)[0]
        for iii in idx:
            iii1 = np.where((jo1 == jo1[iii] + 1) * (jo2 == jo2[iii] + 1) * (jo3 == jo3[iii] + 1))[0]
            iii2 = np.where((jo1 == jo1[iii] + 2) * (jo2 == jo2[iii] + 2) * (jo3 == jo3[iii] + 2))[0]
            if iii2.shape[0] > 0:
                for i in range(c11_in.shape[0]):
                    for j in range(c11_in.shape[1]):
                        for k in range(c11_in.shape[3]):
                            for l in range(c11_in.shape[4]):
                                for m in range(c11_in.shape[5]):
                                    c11[i, j, iii, k, l, m] = self.add_data_from_slope(c11[i, j, [iii1, iii2], k, l, m], 1, ds=2)[0]

        # ... then from the two neighbours with smaller j3.
        idx = np.where(abs(c11[0, 0, :, 0, 0, 0]) == 0)[0]
        for iii in idx:
            iii1 = np.where((jo1 == jo1[iii]) * (jo2 == jo2[iii]) * (jo3 == jo3[iii] - 1))[0]
            iii2 = np.where((jo1 == jo1[iii]) * (jo2 == jo2[iii]) * (jo3 == jo3[iii] - 2))[0]
            if iii2.shape[0] > 0:
                for i in range(c11_in.shape[0]):
                    for j in range(c11_in.shape[1]):
                        for k in range(c11_in.shape[3]):
                            for l in range(c11_in.shape[4]):
                                for m in range(c11_in.shape[5]):
                                    c11[i, j, iii, k, l, m] = self.add_data_from_slope(c11[i, j, [iii1, iii2], k, l, m], 1, ds=2)[0]

        return scat_cov1D(p00, c01, c11, s1=s1, backend=self.backend)
999
-
1000
-
1001
-
1002
- class funct(FOC.FoCUS):
1003
-
1004
def fill(self, im, nullval=0):
    """Delegate to ``self.fill_1d`` with the same image and ``nullval``."""
    filler = self.fill_1d
    return filler(im, nullval=nullval)
1006
-
1007
def ud_grade(self, im, nout, axis=0):
    """Delegate to ``self.ud_grade_1d`` with target size ``nout`` along ``axis``."""
    resample = self.ud_grade_1d
    return resample(im, nout, axis=axis)
1009
-
1010
def up_grade(self, im, nout, axis=0):
    """Delegate to ``self.up_grade_1d`` with target size ``nout`` along ``axis``."""
    upsample = self.up_grade_1d
    return upsample(im, nout, axis=axis)
1012
-
1013
def smooth(self, data, axis=0):
    """Delegate to ``self.smooth_1d`` along ``axis``."""
    smoother = self.smooth_1d
    return smoother(data, axis=axis)
1015
-
1016
def convol(self, data, axis=0):
    """Delegate to ``self.convol_1d`` along ``axis``."""
    convolve = self.convol_1d
    return convolve(data, axis=axis)
1018
-
1019
- def eval(self, image1, image2=None, mask=None, Auto=True):
1020
- """
1021
- Calculates the scattering correlations for a batch of images. Mean are done over pixels.
1022
- mean of modulus:
1023
- S1 = <|I * Psi_j3|>
1024
- Normalization : take the log
1025
- power spectrum:
1026
- P00 = <|I * Psi_j3|^2>
1027
- Normalization : take the log
1028
- orig. x modulus:
1029
- C01 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
1030
- Normalization : divide by (P00_j2 * P00_j3)^0.5
1031
- modulus x modulus:
1032
- C11 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
1033
- Normalization : divide by (P00_j1 * P00_j2)^0.5
1034
- Parameters
1035
- ----------
1036
- image1: tensor
1037
- Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
1038
- image2: tensor
1039
- Second image. If not None, we compute cross-scattering covariance coefficients.
1040
- mask:
1041
- norm: None or str
1042
- If None no normalization is applied, if 'auto' normalize by the reference P00,
1043
- if 'self' normalize by the current P00.
1044
- all_cross: False or True
1045
- If False compute all the coefficient even the Imaginary part,
1046
- If True return only the terms computable in the auto case.
1047
- Returns
1048
- -------
1049
- S1, P00, C01, C11 normalized
1050
- """
1051
-
1052
- norm=None
1053
- axis=0
1054
-
1055
- # Check input consistency
1056
- if image2 is not None:
1057
- if list(image1.shape)!=list(image2.shape):
1058
- print('The two input image should have the same size to eval Scattering Covariance')
1059
- return None
1060
- if mask is not None:
1061
- if list(image1.shape)!=list(mask.shape)[1:]:
1062
- print('The mask should have the same size ',mask.shape,'than the input image ',image1.shape,'to eval Scattering Covariance')
1063
- return None
1064
-
1065
- ### AUTO OR CROSS
1066
- cross = False
1067
- if image2 is not None:
1068
- cross = True
1069
- all_cross=Auto
1070
- else:
1071
- all_cross=False
1072
-
1073
- ### PARAMETERS
1074
- # determine jmax and nside corresponding to the input map
1075
- im_shape = image1.shape
1076
-
1077
- npix=im_shape[len(image1.shape)-1]
1078
-
1079
- J = int(np.log(npix) / np.log(2)) # Number of j scales
1080
- Jmax = J - self.OSTEP # Number of steps for the loop on scales
1081
-
1082
- ### LOCAL VARIABLES (IMAGES and MASK)
1083
- if len(image1.shape) == 1:
1084
- I1 = self.backend.bk_cast(self.backend.bk_expand_dims(image1, 0)) # Local image1 [Nbatch, Npix]
1085
- if cross:
1086
- I2 = self.backend.bk_cast(self.backend.bk_expand_dims(image2, 0)) # Local image2 [Nbatch, Npix]
1087
- else:
1088
- I1 = self.backend.bk_cast(image1) # Local image1 [Nbatch, Npix]
1089
- if cross:
1090
- I2 = self.backend.bk_cast(image2) # Local image2 [Nbatch, Npix]
1091
-
1092
- if mask is None:
1093
- vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
1094
- else:
1095
- vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
1096
-
1097
- if self.KERNELSZ > 3:
1098
- # if the kernel size is bigger than 3 increase the binning before smoothing
1099
- I1 = self.up_grade_1d(I1, npix * 2, axis=axis+1)
1100
- vmask = self.up_grade_1d(vmask, npix * 2, axis=axis+1)
1101
- if cross:
1102
- I2 = self.up_grade_1d(I2, npix * 2, axis=axis+1)
1103
-
1104
- # Normalize the masks because they have different pixel numbers
1105
- # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
1106
-
1107
- ### INITIALIZATION
1108
- # Coefficients
1109
- S1, P00, C01, C11, C10 = None, None, None, None, None
1110
-
1111
- # Dictionaries for C01 computation
1112
- M1_dic = {} # M stands for Module M1 = |I1 * Psi|
1113
- if cross:
1114
- M2_dic = {}
1115
-
1116
- # P00 for normalization
1117
- cond_init_P1_dic = ((norm == 'self') or ((norm == 'auto') and (self.P1_dic is None)))
1118
- if norm is None:
1119
- pass
1120
- elif cond_init_P1_dic:
1121
- P1_dic = {}
1122
- if cross:
1123
- P2_dic = {}
1124
- elif (norm == 'auto') and (self.P1_dic is not None):
1125
- P1_dic = self.P1_dic
1126
- if cross:
1127
- P2_dic = self.P2_dic
1128
-
1129
-
1130
- #### COMPUTE S1, P00, C01 and C11
1131
- nside_j3 = npix # NSIDE start (nside_j3 = nside / 2^j3)
1132
-
1133
- for j3 in range(Jmax):
1134
-
1135
- ####### S1 and P00
1136
- ### Make the convolution I1 * Psi_j3
1137
- conv1 = self.convol_1d(I1, axis=axis+1) # [Nbatch, Npix_j3]
1138
- ### Take the module M1 = |I1 * Psi_j3|
1139
- M1_square = conv1*self.backend.bk_conjugate(conv1) # [Nbatch, Npix_j3]
1140
- M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3]
1141
- # Store M1_j3 in a dictionary
1142
- M1_dic[j3] = M1
1143
-
1144
- if not cross: # Auto
1145
- M1_square=self.backend.bk_real(M1_square)
1146
-
1147
- ### P00_auto = < M1^2 >_pix
1148
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1149
- if self.backend.bk_is_complex(M1_square):
1150
- p00 = self.backend.bk_reduce_sum(M1_square[:, None, :]*self.backend.bk_complex(vmask[None,:, :],0*vmask[None,:, :]), axis=2)
1151
- else:
1152
- p00 = self.backend.bk_reduce_sum(M1_square[:, None, :]*vmask[None,:, :], axis=2)
1153
-
1154
- if cond_init_P1_dic:
1155
- # We fill P1_dic with P00 for normalisation of C01 and C11
1156
- P1_dic[j3] = p00 # [Nbatch, Nmask, Norient3]
1157
- if norm == 'auto': # Normalize P00
1158
- p00 /= P1_dic[j3]
1159
-
1160
- # We store P00_auto to return it [Nbatch, Nmask, NP00]
1161
- if P00 is None:
1162
- P00 = p00[:, :, None] # Add a dimension for NP00
1163
- else:
1164
- P00 = self.backend.bk_concat([P00, p00[:, :, None]], axis=axis+2)
1165
-
1166
- #### S1_auto computation
1167
- ### Image 1 : S1 = < M1 >_pix
1168
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1169
- if self.backend.bk_is_complex(M1):
1170
- s1 = self.backend.bk_reduce_sum(M1[:, None, :]*self.backend.bk_complex(vmask[None,:, :],0*vmask[None,:, :]), axis=2)
1171
- else:
1172
- s1 = self.backend.bk_reduce_sum(M1[:, None, :]*vmask[None,:, :], axis=2)
1173
-
1174
- ### Normalize S1
1175
- if norm is not None:
1176
- s1 /= (P1_dic[j3]) ** 0.5
1177
- ### We store S1 for image1 [Nbatch, Nmask, NS1]
1178
- if S1 is None:
1179
- S1 = s1[:, :, None] # Add a dimension for NS1
1180
- else:
1181
- S1 = self.backend.bk_concat([S1, s1[:, :, None]], axis=axis+2)
1182
-
1183
- else: # Cross
1184
- ### Make the convolution I2 * Psi_j3
1185
- conv2 = self.convol_1d(I2, axis=axis+1) # [Nbatch, Npix_j3, Norient3]
1186
- ### Take the module M2 = |I2 * Psi_j3|
1187
- M2_square = conv2*self.backend.bk_conjugate(conv2) # [Nbatch, Npix_j3, Norient3]
1188
- M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
1189
- # Store M2_j3 in a dictionary
1190
- M2_dic[j3] = M2
1191
-
1192
- ### P00_auto = < M2^2 >_pix
1193
- # Not returned, only for normalization
1194
- if cond_init_P1_dic:
1195
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1196
- p1 = self.backend.bk_reduce_sum(M1_square* vmask, axis=axis+1) # [Nbatch, Nmask, Norient3]
1197
- p2 = self.backend.bk_reduce_sum(M2_square* vmask, axis=axis+1) # [Nbatch, Nmask, Norient3]
1198
- # We fill P1_dic with P00 for normalisation of C01 and C11
1199
- P1_dic[j3] = p1 # [Nbatch, Nmask, Norient3]
1200
- P2_dic[j3] = p2 # [Nbatch, Nmask, Norient3]
1201
-
1202
- ### P00_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
1203
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
1204
- p00 = conv1 * self.backend.bk_conjugate(conv2)
1205
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1206
- p00 = self.backend.bk_real(p00)
1207
- p00 = self.backend.bk_reduce_sum(p00*vmask, axis=1)
1208
- print(p00.shape)
1209
- tmp = self.backend.bk_L1(p00) # [Nbatch, Npix_j3, Norient3]
1210
-
1211
- ### Normalize P00_cross
1212
- if norm == 'auto':
1213
- p00 /= (P1_dic[j3] * P2_dic[j3])**0.5
1214
-
1215
- ### Store P00_cross as complex [Nbatch, Nmask, NP00, Norient3]
1216
- if not all_cross:
1217
- p00=self.backend.bk_real(p00)
1218
-
1219
- if P00 is None:
1220
- P00 = p00[:,:,None] # Add a dimension for NP00
1221
- else:
1222
- P00 = self.backend.bk_concat([P00, p00[:,:,None]], axis=axis+2)
1223
-
1224
- #### S1_auto computation
1225
- ### Image 1 : S1 = < M1 >_pix
1226
- # Apply the mask [Nmask, Npix_j3] and average over pixels
1227
- if self.backend.bk_is_complex(s1):
1228
- s1 = self.backend.bk_reduce_sum(tmp[:, None, :]*self.backend.bk_complex(vmask[None,:, :],0*vmask[None,:, :]), axis=2)
1229
- else:
1230
- s1 = self.backend.bk_reduce_sum(tmp[:, None, :]*vmask[None,:, :], axis=2)
1231
-
1232
- ### Normalize S1
1233
- if norm is not None:
1234
- s1 /= (P1_dic[j3]) ** 0.5
1235
- ### We store S1 for image1 [Nbatch, Nmask, NS1]
1236
- if S1 is None:
1237
- S1 = s1[:, :, None] # Add a dimension for NS1
1238
- else:
1239
- S1 = self.backend.bk_concat([S1, s1[:, :, None]], axis=axis+2)
1240
-
1241
- # Initialize dictionaries for |I1*Psi_j| * Psi_j3
1242
- M1convPsi_dic = {}
1243
- if cross:
1244
- # Initialize dictionaries for |I2*Psi_j| * Psi_j3
1245
- M2convPsi_dic = {}
1246
-
1247
- ###### C01
1248
- for j2 in range(0, j3+1): # j2 <= j3
1249
- ### C01_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
1250
- if not cross:
1251
- c01 = self._compute_C01(j2,
1252
- conv1,
1253
- vmask,
1254
- M1_dic,
1255
- M1convPsi_dic) # [Nbatch, Nmask, Norient3, Norient2]
1256
- ### Normalize C01 with P00_j [Nbatch, Nmask, Norient_j]
1257
- if norm is not None:
1258
- c01 /= (P1_dic[j2][:, :, None, :] *
1259
- P1_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
1260
-
1261
- ### Store C01 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
1262
- if C01 is None:
1263
- C01 = c01[:,:,None] # Add a dimension for NC01
1264
- else:
1265
- C01 = self.backend.bk_concat([C01, c01[:, :, None]],
1266
- axis=2) # Add a dimension for NC01
1267
-
1268
- ### C01_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
1269
- ### C10_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
1270
- else:
1271
- c01 = self._compute_C01(j2,
1272
- conv1,
1273
- vmask,
1274
- M2_dic,
1275
- M2convPsi_dic)
1276
- c10 = self._compute_C01(j2,
1277
- conv2,
1278
- vmask,
1279
- M1_dic,
1280
- M1convPsi_dic)
1281
-
1282
- ### Normalize C01 and C10 with P00_j [Nbatch, Nmask, Norient_j]
1283
- if norm is not None:
1284
- c01 /= (P2_dic[j2][:, :, None, :] *
1285
- P1_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
1286
- c10 /= (P1_dic[j2][:, :, None, :] *
1287
- P2_dic[j3][:, :, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2]
1288
-
1289
- ### Store C01 and C10 as a complex [Nbatch, Nmask, NC01, Norient3, Norient2]
1290
- if C01 is None:
1291
- C01 = c01[:, :, None] # Add a dimension for NC01
1292
- else:
1293
- C01 = self.backend.bk_concat([C01,c01[:, :, None]],axis=2) # Add a dimension for NC01
1294
- if C10 is None:
1295
- C10 = c10[:, :, None] # Add a dimension for NC01
1296
- else:
1297
- C10 = self.backend.bk_concat([C10,c10[:, :, None]], axis=2) # Add a dimension for NC01
1298
-
1299
-
1300
-
1301
- ##### C11
1302
- for j1 in range(0, j2+1): # j1 <= j2
1303
- ### C11_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
1304
- if not cross:
1305
- c11 = self._compute_C11(j1, j2, vmask,
1306
- M1convPsi_dic,
1307
- M2convPsi_dic=None) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1308
- ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
1309
- if norm is not None:
1310
- c11 /= (P1_dic[j1][:, :, None, None, :] *
1311
- P1_dic[j2][:, :, None, :,
1312
- None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1313
- ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
1314
- if C11 is None:
1315
- C11 = c11[:, :, None] # Add a dimension for NC11
1316
- else:
1317
- C11 = self.backend.bk_concat([C11,c11[:, :, None]],
1318
- axis=2) # Add a dimension for NC11
1319
-
1320
- ### C11_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
1321
- else:
1322
- c11 = self._compute_C11(j1, j2, vmask,
1323
- M1convPsi_dic,
1324
- M2convPsi_dic=M2convPsi_dic) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1325
- ### Normalize C11 with P00_j [Nbatch, Nmask, Norient_j]
1326
- if norm is not None:
1327
- c11 /= (P1_dic[j1][:, :, None, None, :] *
1328
- P2_dic[j2][:, :, None, :, None]) ** 0.5 # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1329
- ### Store C11 as a complex [Nbatch, Nmask, NC11, Norient3, Norient2, Norient1]
1330
- if C11 is None:
1331
- C11 = c11[:, :, None] # Add a dimension for NC11
1332
- else:
1333
- C11 = self.backend.bk_concat([C11,c11[:, :, None]],
1334
- axis=2) # Add a dimension for NC11
1335
-
1336
- ###### Reshape for next iteration on j3
1337
- ### Image I1,
1338
- # downscale the I1 [Nbatch, Npix_j3]
1339
- if j3 != Jmax - 1:
1340
- I1_smooth = self.smooth_1d(I1, axis=axis+1)
1341
- I1 = self.ud_grade_1d(I1_smooth,I1.shape[axis+1]//2, axis=axis+1)
1342
-
1343
- ### Image I2
1344
- if cross:
1345
- I2_smooth = self.smooth_1d(I2, axis=axis+2)
1346
- I2 = self.ud_grade_1d(I2_smooth,I2.shape[axis+1]//2, axis=axis+1)
1347
-
1348
- ### Modules
1349
- for j2 in range(0, j3 + 1): # j2 =< j3
1350
- ### Dictionary M1_dic[j2]
1351
- M1_smooth = self.smooth_1d(M1_dic[j2], axis=axis+1) # [Nbatch, Npix_j3, Norient3]
1352
- M1_dic[j2] = self.ud_grade_1d(M1_smooth,M1_dic[j2].shape[axis+1]//2, axis=axis+1) # [Nbatch, Npix_j3, Norient3]
1353
-
1354
- ### Dictionary M2_dic[j2]
1355
- if cross:
1356
- M2_smooth = self.smooth_1d(M2_dic[j2], axis=axis+2) # [Nbatch, Npix_j3, Norient3]
1357
- M2_dic[j2] = self.ud_grade_1d(M2_smooth,M2_dic[j2].shape[axis+1]//2, axis=axis+2) # [Nbatch, Npix_j3, Norient3]
1358
- ### Mask
1359
- vmask = self.ud_grade_1d(vmask,vmask.shape[axis+1]//2, axis=1)
1360
-
1361
- if self.mask_thres is not None:
1362
- vmask = self.backend.bk_threshold(vmask,self.mask_thres)
1363
-
1364
- ### NSIDE_j3
1365
- nside_j3 = nside_j3 // 2
1366
-
1367
- ### Store P1_dic and P2_dic in self
1368
- if (norm == 'auto') and (self.P1_dic is None):
1369
- self.P1_dic = P1_dic
1370
- if cross:
1371
- self.P2_dic = P2_dic
1372
-
1373
- if not cross:
1374
- return scat_cov1D(P00, C01, C11, s1=S1,backend=self.backend)
1375
- else:
1376
- return scat_cov1D(P00, C01, C11, c10=C10,backend=self.backend)
1377
-
1378
- def clean_norm(self):
1379
- self.P1_dic = None
1380
- self.P2_dic = None
1381
- return
1382
-
1383
- def _compute_C01(self, j2, conv,
1384
- vmask, M_dic,
1385
- MconvPsi_dic):
1386
- """
1387
- Compute the C01 coefficients (auto or cross)
1388
- C01 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
1389
- Parameters
1390
- ----------
1391
- Returns
1392
- -------
1393
- cc01, sc01: real and imag parts of C01 coeff
1394
- """
1395
- ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
1396
- # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
1397
- MconvPsi = self.convol_1d(M_dic[j2], axis=1) # [Nbatch, Npix_j3, Norient3, Norient2]
1398
-
1399
- # Store it so we can use it in C11 computation
1400
- MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
1401
-
1402
- ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
1403
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
1404
- # cconv, sconv are [Nbatch, Npix_j3, Norient3]
1405
- c01 = conv * self.backend.bk_conjugate(MconvPsi) # [Nbatch, Npix_j3, Norient3, Norient2]
1406
-
1407
- ### Apply the mask [Nmask, Npix_j3] and sum over pixels
1408
- c01 = self.masked_mean(c01, vmask, axis=1,rank=j2) # [Nbatch, Nmask, Norient3, Norient2]
1409
-
1410
- return c01
1411
-
1412
- def _compute_C11(self, j1, j2, vmask,
1413
- M1convPsi_dic,
1414
- M2convPsi_dic=None):
1415
- #### Simplify notations
1416
- M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
1417
-
1418
- # Auto or Cross coefficients
1419
- if M2convPsi_dic is None: # Auto
1420
- M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
1421
- else: # Cross
1422
- M2 = M2convPsi_dic[j2]
1423
-
1424
- ### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
1425
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
1426
- c11 = M1 * M2 # [Nbatch, Npix_j3]
1427
-
1428
- ### Apply the mask and sum over pixels
1429
- c11 = self.masked_mean(c11, vmask, axis=1,rank=j2) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
1430
- return c11
1431
-
1432
- def square(self, x):
1433
- if isinstance(x, scat_cov1D):
1434
- if x.S1 is None:
1435
- return scat_cov1D(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1436
- self.backend.bk_square(self.backend.bk_abs(x.C01)),
1437
- self.backend.bk_square(self.backend.bk_abs(x.C11)),backend=self.backend)
1438
- else:
1439
- return scat_cov1D(self.backend.bk_square(self.backend.bk_abs(x.P00)),
1440
- self.backend.bk_square(self.backend.bk_abs(x.C01)),
1441
- self.backend.bk_square(self.backend.bk_abs(x.C11)),
1442
- s1=self.backend.bk_square(self.backend.bk_abs(x.S1)),backend=self.backend)
1443
- else:
1444
- return self.backend.bk_abs(self.backend.bk_square(x))
1445
-
1446
- def sqrt(self, x):
1447
- if isinstance(x, scat_cov1D):
1448
- if x.S1 is None:
1449
- return scat_cov1D(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1450
- self.backend.bk_sqrt(self.backend.bk_abs(x.C01)),
1451
- self.backend.bk_sqrt(self.backend.bk_abs(x.C11)),backend=self.backend)
1452
- else:
1453
- return scat_cov1D(self.backend.bk_sqrt(self.backend.bk_abs(x.P00)),
1454
- self.backend.bk_sqrt(self.backend.bk_abs(x.C01)),
1455
- self.backend.bk_sqrt(self.backend.bk_abs(x.C11)),
1456
- s1=self.backend.bk_sqrt(self.backend.bk_abs(x.S1)),backend=self.backend)
1457
- else:
1458
- return self.backend.bk_abs(self.backend.bk_sqrt(x))
1459
-
1460
- def reduce_mean(self, x):
1461
- if isinstance(x, scat_cov1D):
1462
- if x.S1 is None:
1463
- result = (self.backend.bk_reduce_mean(self.backend.bk_abs(x.P00)) + \
1464
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.C01)) + \
1465
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.C11)))/3
1466
- else:
1467
- result = (self.backend.bk_reduce_mean(self.backend.bk_abs(x.P00)) + \
1468
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.S1)) + \
1469
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.C01)) + \
1470
- self.backend.bk_reduce_mean(self.backend.bk_abs(x.C11)))/4
1471
- else:
1472
- return self.backend.bk_reduce_mean(x)
1473
- return result
1474
-
1475
- def reduce_sum(self, x):
1476
-
1477
- if isinstance(x, scat_cov1D):
1478
- if x.S1 is None:
1479
- result = self.backend.bk_reduce_sum(x.P00) + \
1480
- self.backend.bk_reduce_sum(x.C01) + \
1481
- self.backend.bk_reduce_sum(x.C11)
1482
- else:
1483
- result = self.backend.bk_reduce_sum(x.P00) + \
1484
- self.backend.bk_reduce_sum(x.S1) + \
1485
- self.backend.bk_reduce_sum(x.C01) + \
1486
- self.backend.bk_reduce_sum(x.C11)
1487
- else:
1488
- return self.backend.bk_reduce_sum(x)
1489
- return result
1490
-
1491
-
1492
- def ldiff(self,sig,x):
1493
-
1494
- if x.S1 is None:
1495
- if x.C11 is not None:
1496
- return scat_cov1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1497
- x.domult(sig.C01,x.C01)*x.domult(sig.C01,x.C01),
1498
- x.domult(sig.C11,x.C11)*x.domult(sig.C11,x.C11),
1499
- backend=self.backend)
1500
- else:
1501
- return scat_cov1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1502
- x.domult(sig.C01,x.C01)*x.domult(sig.C01,x.C01),
1503
- 0*sig.C01,
1504
- backend=self.backend)
1505
- else:
1506
- if x.C11 is None:
1507
- return scat_cov1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1508
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1509
- x.domult(sig.C01,x.C01)*x.domult(sig.C01,x.C01),
1510
- 0*sig.S1,
1511
- backend=self.backend)
1512
- else:
1513
- return scat_cov1D(x.domult(sig.P00,x.P00)*x.domult(sig.P00,x.P00),
1514
- x.domult(sig.S1,x.S1)*x.domult(sig.S1,x.S1),
1515
- x.domult(sig.C01,x.C01)*x.domult(sig.C01,x.C01),
1516
- x.domult(sig.C11,x.C11)*x.domult(sig.C11,x.C11),
1517
- backend=self.backend)
1518
-
1519
-
1520
- def log(self, x):
1521
- if isinstance(x, scat_cov1D):
1522
-
1523
- if x.S1 is None:
1524
- result = self.backend.bk_log(x.P00) + \
1525
- self.backend.bk_log(x.C01) + \
1526
- self.backend.bk_log(x.C11)
1527
- else:
1528
- result = self.backend.bk_log(x.P00) + \
1529
- self.backend.bk_log(x.S1) + \
1530
- self.backend.bk_log(x.C01) + \
1531
- self.backend.bk_log(x.C11)
1532
- else:
1533
- return self.backend.bk_log(x)
1534
-
1535
- return result
1536
-
1537
- @tf_function
1538
- def eval_comp_fast(self, image1, image2=None,mask=None,norm=None, Auto=True,Add_R45=False):
1539
-
1540
- res=self.eval(image1, image2=image2,mask=mask,Auto=Auto,Add_R45=Add_R45)
1541
- return res.P00,res.S1,res.C01,res.C11,res.C10
1542
-
1543
- def eval_fast(self, image1, image2=None,mask=None,norm=None, Auto=True,Add_R45=False):
1544
- p0,s1,c01,c11,c10=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,Add_R45=Add_R45)
1545
- return scat_cov1D(p0, c01, c11, s1=s1,c10=c10,backend=self.backend)
1546
-
1547
-