foscat 3.0.9__tar.gz → 3.0.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {foscat-3.0.9 → foscat-3.0.14}/PKG-INFO +1 -1
- {foscat-3.0.9 → foscat-3.0.14}/setup.py +1 -1
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/FoCUS.py +140 -27
- foscat-3.0.14/src/foscat/GCNN.py +100 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/Synthesis.py +11 -6
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/backend.py +41 -1
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov.py +135 -20
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/PKG-INFO +1 -1
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/SOURCES.txt +1 -0
- {foscat-3.0.9 → foscat-3.0.14}/README.md +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/setup.cfg +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/CircSpline.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/GetGPUinfo.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/Softmax.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/Spline1D.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/__init__.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/loss_backend_tens.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/loss_backend_torch.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat1D.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat2D.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov1D.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov2D.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov_map.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat/scat_cov_map2D.py +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/dependency_links.txt +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/requires.txt +0 -0
- {foscat-3.0.9 → foscat-3.0.14}/src/foscat.egg-info/top_level.txt +0 -0
|
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
|
|
3
3
|
|
|
4
4
|
setup(
|
|
5
5
|
name='foscat',
|
|
6
|
-
version='3.0.
|
|
6
|
+
version='3.0.14',
|
|
7
7
|
description='Generate synthetic Healpix or 2D data using Cross Scattering Transform' ,
|
|
8
8
|
long_description='Utilize the Cross Scattering Transform (described in https://arxiv.org/abs/2207.12527) to synthesize Healpix or 2D data that is suitable for component separation purposes, such as denoising. \n A demo package for this process can be found at https://github.com/jmdelouis/FOSCAT_DEMO. \n Complete doc can be found at https://foscat-documentation.readthedocs.io/en/latest/index.html. \n\n List of developers : J.-M. Delouis, T. Foulquier, L. Mousset, T. Odaka, F. Paul, E. Allys ' ,
|
|
9
9
|
license='MIT',
|
|
@@ -27,6 +27,7 @@ class FoCUS:
|
|
|
27
27
|
JmaxDelta=0,
|
|
28
28
|
DODIV=False,
|
|
29
29
|
InitWave=None,
|
|
30
|
+
silent=False,
|
|
30
31
|
mpi_size=1,
|
|
31
32
|
mpi_rank=0):
|
|
32
33
|
|
|
@@ -42,20 +43,25 @@ class FoCUS:
|
|
|
42
43
|
self.mpi_size=mpi_size
|
|
43
44
|
self.mpi_rank=mpi_rank
|
|
44
45
|
self.return_data=return_data
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
46
|
+
self.silent=silent
|
|
47
|
+
|
|
48
|
+
if not silent:
|
|
49
|
+
print('================================================')
|
|
50
|
+
print(' START FOSCAT CONFIGURATION')
|
|
51
|
+
print('================================================')
|
|
52
|
+
sys.stdout.flush()
|
|
50
53
|
|
|
51
54
|
self.TEMPLATE_PATH=TEMPLATE_PATH
|
|
52
55
|
if os.path.exists(self.TEMPLATE_PATH)==False:
|
|
53
|
-
|
|
56
|
+
if not silent:
|
|
57
|
+
print('The directory %s to store temporary information for FoCUS does not exist: Try to create it'%(self.TEMPLATE_PATH))
|
|
54
58
|
try:
|
|
55
59
|
os.system('mkdir -p %s'%(self.TEMPLATE_PATH))
|
|
56
|
-
|
|
60
|
+
if not silent:
|
|
61
|
+
print('The directory %s is created')
|
|
57
62
|
except:
|
|
58
|
-
|
|
63
|
+
if not silent:
|
|
64
|
+
print('Impossible to create the directory %s'%(self.TEMPLATE_PATH))
|
|
59
65
|
exit(0)
|
|
60
66
|
|
|
61
67
|
self.number_of_loss=0
|
|
@@ -65,13 +71,15 @@ class FoCUS:
|
|
|
65
71
|
self.padding=padding
|
|
66
72
|
|
|
67
73
|
if OSTEP!=0:
|
|
68
|
-
|
|
74
|
+
if not silent:
|
|
75
|
+
print('OPTION option is deprecated after version 2.0.6. Please use Jmax option')
|
|
69
76
|
JmaxDelta=OSTEP
|
|
70
77
|
else:
|
|
71
78
|
OSTEP=JmaxDelta
|
|
72
79
|
|
|
73
80
|
if JmaxDelta<-1:
|
|
74
|
-
|
|
81
|
+
if not silent:
|
|
82
|
+
print('Warning : Jmax can not be smaller than -1')
|
|
75
83
|
exit(0)
|
|
76
84
|
|
|
77
85
|
self.OSTEP=JmaxDelta
|
|
@@ -103,14 +111,15 @@ class FoCUS:
|
|
|
103
111
|
|
|
104
112
|
self.gpupos=(gpupos+mpi_rank)%self.backend.ngpu
|
|
105
113
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
+
if not silent:
|
|
115
|
+
print('============================================================')
|
|
116
|
+
print('== ==')
|
|
117
|
+
print('== ==')
|
|
118
|
+
print('== RUN ON GPU Rank %d : %s =='%(mpi_rank,self.gpulist[self.gpupos%self.ngpu]))
|
|
119
|
+
print('== ==')
|
|
120
|
+
print('== ==')
|
|
121
|
+
print('============================================================')
|
|
122
|
+
sys.stdout.flush()
|
|
114
123
|
|
|
115
124
|
l_NORIENT=NORIENT
|
|
116
125
|
if DODIV:
|
|
@@ -126,6 +135,8 @@ class FoCUS:
|
|
|
126
135
|
|
|
127
136
|
self.ww_Real = {}
|
|
128
137
|
self.ww_Imag = {}
|
|
138
|
+
self.ww_CNN_Transpose = {}
|
|
139
|
+
self.ww_CNN = {}
|
|
129
140
|
|
|
130
141
|
wwc=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
|
|
131
142
|
wws=np.zeros([KERNELSZ**2,l_NORIENT]).astype(all_type)
|
|
@@ -210,7 +221,8 @@ class FoCUS:
|
|
|
210
221
|
|
|
211
222
|
for i in range(1,6):
|
|
212
223
|
lout=(2**i)
|
|
213
|
-
|
|
224
|
+
if not silent:
|
|
225
|
+
print('Init Wave ',lout)
|
|
214
226
|
|
|
215
227
|
if self.InitWave is None:
|
|
216
228
|
wr,wi,ws,widx=self.init_index(lout)
|
|
@@ -271,6 +283,8 @@ class FoCUS:
|
|
|
271
283
|
self.nest2R4[lout]=None
|
|
272
284
|
self.inv_nest2R[lout]=None
|
|
273
285
|
self.remove_border[lout]=None
|
|
286
|
+
self.ww_CNN_Transpose[lout]=None
|
|
287
|
+
self.ww_CNN[lout]=None
|
|
274
288
|
|
|
275
289
|
self.loss={}
|
|
276
290
|
|
|
@@ -296,7 +310,97 @@ class FoCUS:
|
|
|
296
310
|
res=x
|
|
297
311
|
res[idx]=y[idx]
|
|
298
312
|
return(res)
|
|
313
|
+
|
|
314
|
+
# ---------------------------------------------−---------
|
|
315
|
+
# make the CNN working : index reporjection of the kernel on healpix
|
|
316
|
+
|
|
317
|
+
def init_CNN_index(self, nside, transpose=False):
    """Build (or reload from the on-disk cache) the sparse gather operator of the HEALPix CNN.

    For every pixel ``k`` of a nested HEALPix map at resolution ``nside``, the
    ``KERNELSZ*KERNELSZ`` pixels closest to the point (1, 0, 0) are rotated onto
    pixel ``k``. Row ``k*l_kernel + i`` of the resulting sparse matrix selects the
    ``i``-th of those neighbours (all weights are 1).

    The integer index table is cached as a ``.npy`` file under
    ``self.TEMPLATE_PATH``. The sparse tensor built from it is stored in
    ``self.ww_CNN[nside]`` or, when ``transpose`` is True, in
    ``self.ww_CNN_Transpose[nside]``.

    Args:
        nside: HEALPix resolution of the output pixels (matrix rows).
        transpose: If True, neighbour indices are divided by 4 (nested parent
            pixel), so the operator reads a map at resolution ``nside//2``
            (``3*nside*nside`` columns). This is used by the up-sampling layer.
    """
    l_kernel = int(self.KERNELSZ * self.KERNELSZ)
    npix = 12 * nside * nside
    weights = self.backend.bk_cast(np.ones([npix * l_kernel], dtype='float'))

    # Use one name for both reading and writing the cache. The previous code
    # saved the non-transposed table as '..._CNNnpy' (missing dot), so it was
    # never found again and was recomputed on every call.
    suffix = 'CNN_Transpose' if transpose else 'CNN'
    cache_file = '%s/FOSCAT_%s_W%d_%d_%d_%s.npy' % (self.TEMPLATE_PATH, TMPFILE_VERSION,
                                                   l_kernel, self.NORIENT, nside, suffix)

    try:
        indices = np.load(cache_file)
    except (OSError, ValueError, EOFError):
        # Cache missing or unreadable: recompute the index table.
        to, po = hp.pix2ang(nside, np.arange(npix), nest=True)
        x, y, z = hp.pix2vec(nside, np.arange(npix), nest=True)

        # Reference kernel footprint: the l_kernel pixels nearest to (1, 0, 0).
        idx = np.argsort((x - 1.0) ** 2 + y ** 2 + z ** 2)[0:l_kernel]
        tc, pc = hp.pix2ang(nside, idx, nest=True)

        indices = np.zeros([npix, l_kernel, 2], dtype='int')
        for k in range(npix):
            if k % (nside * nside) == 0:
                if not self.silent:
                    print('Pre-compute nside=%6d %.2f%%' % (nside, 100 * k / (12 * nside * nside)))

            # Rotate the reference footprint so that it is centred on pixel k.
            rot = [po[k] / np.pi * 180.0, 90 + (-to[k]) / np.pi * 180.0]
            r = hp.Rotator(rot=rot).get_inverse()
            ty, tx = r(tc, pc)

            indices[k, :, 0] = k * l_kernel + np.arange(l_kernel).astype('int')
            indices[k, :, 1] = hp.ang2pix(nside, ty, tx, nest=True)

        if transpose:
            # Nested indexing: the parent of pixel p at nside//2 is p//4.
            indices[:, :, 1] = indices[:, :, 1] // 4

        np.save(cache_file, indices)
        if not self.silent:
            print('Write %s' % (cache_file))

    flat_indices = indices.reshape(npix * l_kernel, 2)
    if transpose:
        self.ww_CNN_Transpose[nside] = self.backend.bk_SparseTensor(flat_indices,
                                                                    weights, [npix * l_kernel,
                                                                              3 * nside * nside])
    else:
        self.ww_CNN[nside] = self.backend.bk_SparseTensor(flat_indices,
                                                          weights, [npix * l_kernel,
                                                                    npix])
|
|
364
|
+
|
|
365
|
+
# ---------------------------------------------−---------
|
|
366
|
+
def healpix_layer_transpose(self, im, ww):
    """Up-sampling HEALPix convolution: a map at ``nside//2`` becomes a map at ``nside``.

    Args:
        im: input map indexed as [pixel, channel]. ``im.shape[0]`` is
            ``12*(nside//2)**2`` and the output resolution is
            ``nside = 2*sqrt(im.shape[0]//12)``.
        ww: kernel weights with ``ww.shape[1] == im.shape[1]``. They are reshaped
            to ``[KERNELSZ**2 * im.shape[1], ww.shape[2]]``.

    Returns:
        The ``[12*nside**2, ww.shape[2]]`` output map, or -1 (after a message
        unless ``self.silent``) when the channel counts of ``im`` and ``ww``
        disagree.
    """
    nside = 2 * int(np.sqrt(im.shape[0] // 12))
    l_kernel = self.KERNELSZ * self.KERNELSZ
    nchan = im.shape[1]

    if nchan != ww.shape[1]:
        if not self.silent:
            print('Weights channels should be equal to the input image channels')
        return -1

    # The operator cache is only pre-filled for a few resolutions, so use .get()
    # to build it on demand instead of raising KeyError for an unseen nside.
    if self.ww_CNN_Transpose.get(nside) is None:
        self.init_CNN_index(nside, transpose=True)

    # Gather the l_kernel neighbours (at the parent resolution) of every output pixel.
    gathered = self.backend.bk_sparse_dense_matmul(self.ww_CNN_Transpose[nside], im)
    density = self.backend.bk_reshape(gathered, [12 * nside * nside, l_kernel * nchan])

    kernel = self.backend.bk_reshape(ww, [l_kernel * nchan, ww.shape[2]])
    return self.backend.bk_matmul(density, kernel)
|
|
383
|
+
|
|
384
|
+
# ---------------------------------------------−---------
|
|
385
|
+
# ---------------------------------------------−---------
|
|
386
|
+
def healpix_layer(self, im, ww):
    """HEALPix convolution at constant resolution.

    Args:
        im: input map indexed as [pixel, channel], with
            ``im.shape[0] == 12*nside**2``.
        ww: kernel weights with ``ww.shape[1] == im.shape[1]``. They are reshaped
            to ``[KERNELSZ**2 * im.shape[1], ww.shape[2]]``.

    Returns:
        The ``[12*nside**2, ww.shape[2]]`` output map, or -1 (after a message
        unless ``self.silent``) when the channel counts of ``im`` and ``ww``
        disagree.
    """
    nside = int(np.sqrt(im.shape[0] // 12))
    l_kernel = self.KERNELSZ * self.KERNELSZ
    nchan = im.shape[1]

    if nchan != ww.shape[1]:
        if not self.silent:
            print('Weights channels should be equal to the input image channels')
        return -1

    # The operator cache is only pre-filled for a few resolutions, so use .get()
    # to build it on demand instead of raising KeyError for an unseen nside.
    if self.ww_CNN.get(nside) is None:
        self.init_CNN_index(nside, transpose=False)

    # Gather the l_kernel neighbours of every pixel, then apply the kernel as one matmul.
    gathered = self.backend.bk_sparse_dense_matmul(self.ww_CNN[nside], im)
    density = self.backend.bk_reshape(gathered, [12 * nside * nside, l_kernel * nchan])

    kernel = self.backend.bk_reshape(ww, [l_kernel * nchan, ww.shape[2]])
    return self.backend.bk_matmul(density, kernel)
|
|
402
|
+
# ---------------------------------------------−---------
|
|
403
|
+
|
|
300
404
|
# ---------------------------------------------−---------
|
|
301
405
|
def get_rank(self):
|
|
302
406
|
return(self.rank)
|
|
@@ -332,7 +436,8 @@ class FoCUS:
|
|
|
332
436
|
if self.use_2D:
|
|
333
437
|
ishape=list(im.shape)
|
|
334
438
|
if len(ishape)<axis+2:
|
|
335
|
-
|
|
439
|
+
if not self.silent:
|
|
440
|
+
print('Use of 2D scat with data that has less than 2D')
|
|
336
441
|
exit(0)
|
|
337
442
|
|
|
338
443
|
npix=im.shape[axis]
|
|
@@ -392,7 +497,8 @@ class FoCUS:
|
|
|
392
497
|
if self.use_2D:
|
|
393
498
|
ishape=list(im.shape)
|
|
394
499
|
if len(ishape)<axis+2:
|
|
395
|
-
|
|
500
|
+
if not self.silent:
|
|
501
|
+
print('Use of 2D scat with data that has less than 2D')
|
|
396
502
|
exit(0)
|
|
397
503
|
|
|
398
504
|
if nouty is None:
|
|
@@ -434,7 +540,8 @@ class FoCUS:
|
|
|
434
540
|
lout=int(np.sqrt(im.shape[axis]//12))
|
|
435
541
|
|
|
436
542
|
if self.pix_interp_val[lout][nout] is None:
|
|
437
|
-
|
|
543
|
+
if not self.silent:
|
|
544
|
+
print('compute lout nout',lout,nout)
|
|
438
545
|
th,ph=hp.pix2ang(nout,np.arange(12*nout**2,dtype='int'),nest=True)
|
|
439
546
|
p, w = hp.get_interp_weights(lout,th,ph,nest=True)
|
|
440
547
|
del th
|
|
@@ -791,7 +898,8 @@ class FoCUS:
|
|
|
791
898
|
|
|
792
899
|
for k in range(12*nside*nside):
|
|
793
900
|
if k%(nside*nside)==0:
|
|
794
|
-
|
|
901
|
+
if not self.silent:
|
|
902
|
+
print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
|
|
795
903
|
if nside>scale*2:
|
|
796
904
|
lidx=hp.get_all_neighbours(nside//scale,th[k//(scale*scale)],ph[k//(scale*scale)],nest=True)
|
|
797
905
|
lidx=np.concatenate([lidx,np.array([(k//(scale*scale))])],0)
|
|
@@ -843,7 +951,8 @@ class FoCUS:
|
|
|
843
951
|
wav=w.flatten()
|
|
844
952
|
wwav=wwav.flatten()
|
|
845
953
|
|
|
846
|
-
|
|
954
|
+
if not self.silent:
|
|
955
|
+
print('Write FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside))
|
|
847
956
|
np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice)
|
|
848
957
|
np.save('%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),wav)
|
|
849
958
|
np.save('%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy'%(self.TEMPLATE_PATH,TMPFILE_VERSION,self.KERNELSZ**2,self.NORIENT,nside),indice2)
|
|
@@ -857,7 +966,8 @@ class FoCUS:
|
|
|
857
966
|
self.comp_idx_w25(nside)
|
|
858
967
|
else:
|
|
859
968
|
if self.rank==0:
|
|
860
|
-
|
|
969
|
+
if not self.silent:
|
|
970
|
+
print('Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d'%(KERNELSZ,KERNELSZ))
|
|
861
971
|
exit(0)
|
|
862
972
|
|
|
863
973
|
self.barrier()
|
|
@@ -882,6 +992,7 @@ class FoCUS:
|
|
|
882
992
|
return tmp
|
|
883
993
|
|
|
884
994
|
return wr,wi,ws,tmp
|
|
995
|
+
|
|
885
996
|
|
|
886
997
|
# ---------------------------------------------−---------
|
|
887
998
|
# Compute x [....,a,....] to [....,a*a,....]
|
|
@@ -1140,7 +1251,8 @@ class FoCUS:
|
|
|
1140
1251
|
|
|
1141
1252
|
ishape=list(in_image.shape)
|
|
1142
1253
|
if len(ishape)<axis+2:
|
|
1143
|
-
|
|
1254
|
+
if not self.silent:
|
|
1255
|
+
print('Use of 2D scat with data that has less than 2D')
|
|
1144
1256
|
exit(0)
|
|
1145
1257
|
|
|
1146
1258
|
npix=ishape[axis]
|
|
@@ -1264,7 +1376,8 @@ class FoCUS:
|
|
|
1264
1376
|
|
|
1265
1377
|
ishape=list(in_image.shape)
|
|
1266
1378
|
if len(ishape)<axis+2:
|
|
1267
|
-
|
|
1379
|
+
if not self.silent:
|
|
1380
|
+
print('Use of 2D scat with data that has less than 2D')
|
|
1268
1381
|
exit(0)
|
|
1269
1382
|
|
|
1270
1383
|
npix=ishape[axis]
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pickle
|
|
3
|
+
import foscat.scat_cov as sc
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class GCNN:
    """Generative CNN mapping a small parameter vector to a HEALPix map.

    Network (see ``eval``):
      1. a dense layer from ``nparam`` inputs to ``12*in_nside**2`` pixels x
         ``chanlist[0]`` channels, followed by ReLU;
      2. ``nscale`` layers built on ``scat_operator.healpix_layer_transpose``
         (``chanlist[k]`` -> ``chanlist[k+1]`` channels), each followed by ReLU;
      3. a final per-pixel linear combination of the ``chanlist[nscale]``
         channels into a single value.

    All trainable weights live in the flat vector ``self.x``, stored in that order.
    """

    def __init__(self,
                 scat_operator=None,
                 nparam=1,
                 nscale=1,
                 chanlist=None,
                 in_nside=1,
                 nbatch=1,
                 SEED=1234,
                 filename=None):
        """Create a new network, or restore one written by ``save``.

        Args:
            scat_operator: FoCUS/scat_cov operator providing ``backend``,
                ``KERNELSZ``, ``all_type`` and ``healpix_layer_transpose``
                (ignored when ``filename`` is given).
            nparam: number of input parameters.
            nscale: number of up-sampling layers.
            chanlist: channel counts, must have ``nscale+1`` entries.
            in_nside: HEALPix resolution of the dense layer output.
            nbatch: stored for bookkeeping only.
            SEED: seed of the global numpy RNG used to draw the initial weights.
            filename: if given, load ``<filename>.pkl`` instead of initialising.
        """
        if filename is not None:
            # NOTE(security): pickle can execute arbitrary code; only load trusted files.
            with open("%s.pkl" % (filename), "rb") as fin:
                outlist = pickle.load(fin)

            # Layout written by save(): [chanlist, nbatch, npar, KERNELSZ,
            # in_nside, nscale, weights, all_type].
            self.scat_operator = sc.funct(KERNELSZ=outlist[3], all_type=outlist[7])
            self.KERNELSZ = self.scat_operator.KERNELSZ
            self.all_type = self.scat_operator.all_type
            self.npar = outlist[2]
            self.nscale = outlist[5]
            self.chanlist = outlist[0]
            self.in_nside = outlist[4]
            self.nbatch = outlist[1]

            self.x = self.scat_operator.backend.bk_cast(outlist[6])
        else:
            if chanlist is None:
                chanlist = []

            self.nscale = nscale
            self.nbatch = nbatch
            self.npar = nparam
            self.scat_operator = scat_operator

            if len(chanlist) != nscale + 1:
                print('len of chanlist (here %d) should be nscale+1 (here %d)' % (len(chanlist), nscale + 1))
                exit(0)

            self.chanlist = chanlist
            self.KERNELSZ = scat_operator.KERNELSZ
            self.all_type = scat_operator.all_type
            self.in_nside = in_nside

            np.random.seed(SEED)
            self.x = scat_operator.backend.bk_cast(
                np.random.randn(self.get_number_of_weights()) / (self.KERNELSZ * self.KERNELSZ))

    def save(self, filename):
        """Pickle the network configuration and weights to ``<filename>.pkl``."""
        weights = self.get_weights()
        # TensorFlow/torch tensors expose .numpy(); plain numpy-backend arrays do not.
        weights = weights.numpy() if hasattr(weights, 'numpy') else np.asarray(weights)

        outlist = [self.chanlist,
                   self.nbatch,
                   self.npar,
                   self.KERNELSZ,
                   self.in_nside,
                   self.nscale,
                   weights,
                   self.all_type]
        with open("%s.pkl" % (filename), "wb") as fout:
            pickle.dump(outlist, fout)

    def get_number_of_weights(self):
        """Return the length of the flat weight vector ``self.x``."""
        dense = self.npar * 12 * self.in_nside ** 2 * self.chanlist[0]
        conv = sum(self.chanlist[i] * self.chanlist[i + 1] for i in range(self.nscale))
        return dense + conv * self.KERNELSZ * self.KERNELSZ + self.chanlist[self.nscale]

    def set_weights(self, x):
        """Replace the flat weight vector."""
        self.x = x

    def get_weights(self):
        """Return the flat weight vector."""
        return self.x

    def eval(self, param):
        """Run the network on a parameter vector of length ``npar``.

        Returns a 1-D map whose length is the pixel count produced by the last layer.
        """
        bk = self.scat_operator.backend
        x = self.x
        ksq = self.KERNELSZ * self.KERNELSZ
        npix = 12 * self.in_nside ** 2

        # Dense layer: [1, npar] x [npar, npix*chanlist[0]].
        ndense = self.npar * npix * self.chanlist[0]
        ww = bk.bk_reshape(x[0:ndense], [self.npar, npix * self.chanlist[0]])
        im = bk.bk_matmul(bk.bk_reshape(param, [1, self.npar]), ww)
        im = bk.bk_reshape(im, [npix, self.chanlist[0]])
        im = bk.bk_relu(im)

        # Up-sampling convolution layers.
        nn = ndense
        for k in range(self.nscale):
            nw = ksq * self.chanlist[k] * self.chanlist[k + 1]
            ww = bk.bk_reshape(x[nn:nn + nw], [ksq, self.chanlist[k], self.chanlist[k + 1]])
            nn = nn + nw
            im = self.scat_operator.healpix_layer_transpose(im, ww)
            im = bk.bk_relu(im)

        # Final linear combination of the remaining channels.
        ww = bk.bk_reshape(x[nn:], [self.chanlist[self.nscale], 1])
        im = bk.bk_matmul(im, ww)

        return bk.bk_reshape(im, [im.shape[0]])
|
|
100
|
+
|
|
@@ -49,6 +49,7 @@ class Synthesis:
|
|
|
49
49
|
|
|
50
50
|
self.loss_class=loss_list
|
|
51
51
|
self.number_of_loss=len(loss_list)
|
|
52
|
+
self.__iteration__=1234
|
|
52
53
|
self.nlog=0
|
|
53
54
|
self.m_dw, self.v_dw = 0.0, 0.0
|
|
54
55
|
self.beta1 = beta1
|
|
@@ -115,7 +116,7 @@ class Synthesis:
|
|
|
115
116
|
|
|
116
117
|
self.nlog=self.nlog+1
|
|
117
118
|
self.itt2=0
|
|
118
|
-
|
|
119
|
+
|
|
119
120
|
if self.itt%self.EVAL_FREQUENCY==0 and self.mpi_rank==0:
|
|
120
121
|
end = time.time()
|
|
121
122
|
cur_loss='%10.3g ('%(self.ltot[self.ltot!=-1].mean())
|
|
@@ -131,6 +132,7 @@ class Synthesis:
|
|
|
131
132
|
for k in range(info_gpu.shape[0]):
|
|
132
133
|
mess=mess+'[GPU%d %.0f/%.0f MB %.0f%%]'%(k,info_gpu[k,0],info_gpu[k,1],info_gpu[k,2])
|
|
133
134
|
|
|
135
|
+
|
|
134
136
|
print('%sItt %6d L=%s %.3fs %s'%(self.MESSAGE,self.itt,cur_loss,(end-self.start),mess))
|
|
135
137
|
sys.stdout.flush()
|
|
136
138
|
if self.KEEP_TRACK is not None:
|
|
@@ -140,13 +142,14 @@ class Synthesis:
|
|
|
140
142
|
self.start = time.time()
|
|
141
143
|
|
|
142
144
|
self.itt=self.itt+1
|
|
143
|
-
|
|
145
|
+
|
|
144
146
|
# ---------------------------------------------−---------
|
|
145
147
|
def calc_grad(self,in_x):
|
|
146
148
|
|
|
147
149
|
g_tot=None
|
|
148
150
|
l_tot=0.0
|
|
149
151
|
|
|
152
|
+
|
|
150
153
|
if self.do_all_noise and self.totalsz>self.batchsz:
|
|
151
154
|
nstep=self.totalsz//self.batchsz
|
|
152
155
|
else:
|
|
@@ -158,6 +161,7 @@ class Synthesis:
|
|
|
158
161
|
|
|
159
162
|
for istep in range(nstep):
|
|
160
163
|
|
|
164
|
+
|
|
161
165
|
for k in range(self.number_of_loss):
|
|
162
166
|
if self.loss_class[k].batch is None:
|
|
163
167
|
l_batch=None
|
|
@@ -271,6 +275,7 @@ class Synthesis:
|
|
|
271
275
|
self.SHOWGPU=SHOWGPU
|
|
272
276
|
self.axis=axis
|
|
273
277
|
self.in_x_nshape=in_x.shape[0]
|
|
278
|
+
self.seed=1234
|
|
274
279
|
|
|
275
280
|
np.random.seed(self.mpi_rank*7+1234)
|
|
276
281
|
|
|
@@ -347,7 +352,7 @@ class Synthesis:
|
|
|
347
352
|
start_x=x.copy()
|
|
348
353
|
|
|
349
354
|
for iteration in range(NUM_STEP_BIAS):
|
|
350
|
-
|
|
355
|
+
|
|
351
356
|
x,l,i=opt.fmin_l_bfgs_b(self.calc_grad,
|
|
352
357
|
x.astype('float64'),
|
|
353
358
|
callback=self.info_back,
|
|
@@ -357,17 +362,17 @@ class Synthesis:
|
|
|
357
362
|
|
|
358
363
|
# update bias input data
|
|
359
364
|
if iteration<NUM_STEP_BIAS-1:
|
|
360
|
-
if self.mpi_rank==0:
|
|
361
|
-
|
|
365
|
+
#if self.mpi_rank==0:
|
|
366
|
+
# print('%s Hessian restart'%(self.MESSAGE))
|
|
362
367
|
|
|
363
368
|
omap=self.xtractmap(x,axis)
|
|
364
369
|
|
|
365
370
|
for k in range(self.number_of_loss):
|
|
366
371
|
if self.loss_class[k].batch_update is not None:
|
|
367
372
|
self.loss_class[k].batch_update(self.loss_class[k].batch_data,omap)
|
|
373
|
+
if self.loss_class[k].batch is not None:
|
|
368
374
|
l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
|
|
369
375
|
#x=start_x.copy()
|
|
370
|
-
|
|
371
376
|
|
|
372
377
|
if self.mpi_rank==0 and SHOWGPU:
|
|
373
378
|
self.stop_synthesis()
|
|
@@ -297,9 +297,14 @@ class foscat_backend:
|
|
|
297
297
|
if self.BACKEND==self.TENSORFLOW:
|
|
298
298
|
return self.backend.nn.conv1d(x,w, stride=[1,1,1], padding='SAME')
|
|
299
299
|
if self.BACKEND==self.TORCH:
|
|
300
|
+
# Torch not yet done !!!
|
|
300
301
|
return self.backend.nn.conv1d(x,w, stride=1, padding='SAME')
|
|
301
302
|
if self.BACKEND==self.NUMPY:
|
|
302
|
-
|
|
303
|
+
res=np.zeros([x.shape[0],x.shape[1],w.shape[1]],dtype=x.dtype)
|
|
304
|
+
for k in range(w.shape[1]):
|
|
305
|
+
for l in range(w.shape[2]):
|
|
306
|
+
res[:,:,l]+=self.scipy.ndimage.convolve1d(x[:,:,k],w[:,k,l],axis=1)
|
|
307
|
+
return res
|
|
303
308
|
|
|
304
309
|
def bk_flattenR(self,x):
|
|
305
310
|
if self.BACKEND==self.TENSORFLOW or self.BACKEND==self.TORCH:
|
|
@@ -480,6 +485,41 @@ class foscat_backend:
|
|
|
480
485
|
return(self.backend.mean(data,axis))
|
|
481
486
|
if self.BACKEND==self.NUMPY:
|
|
482
487
|
return(np.mean(data,axis))
|
|
488
|
+
|
|
489
|
+
def bk_reduce_min(self, data, axis=None):
    """Minimum of ``data`` over all elements (``axis=None``) or along ``axis``.

    Dispatches on ``self.BACKEND``. It returns None when the backend is not
    TENSORFLOW, TORCH or NUMPY, as before.
    """
    if self.BACKEND == self.TENSORFLOW:
        if axis is None:
            return self.backend.reduce_min(data)
        return self.backend.reduce_min(data, axis=axis)
    if self.BACKEND == self.TORCH:
        if axis is None:
            return self.backend.min(data)
        # torch.min(data, dim) returns a (values, indices) pair; keep only the values
        # so all backends return the same kind of result.
        return self.backend.min(data, axis)[0]
    if self.BACKEND == self.NUMPY:
        return np.min(data, axis)
|
|
505
|
+
|
|
506
|
+
def bk_random_seed(self, value):
    """Seed the global random number generator of the active backend with ``value``."""
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.random.set_seed(value)
    if self.BACKEND == self.TORCH:
        # torch has no random.set_seed; torch.manual_seed seeds the global generator.
        return self.backend.manual_seed(value)
    if self.BACKEND == self.NUMPY:
        return np.random.seed(value)
|
|
514
|
+
|
|
515
|
+
def bk_random_uniform(self, shape):
    """Draw samples from U[0, 1) with the given ``shape`` (int or sequence of ints)."""
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.random.uniform(shape)
    if self.BACKEND == self.TORCH:
        # torch has no random.uniform; torch.rand samples U[0, 1) and accepts a shape sequence.
        return self.backend.rand(shape)
    if self.BACKEND == self.NUMPY:
        # np.random.rand expects separate int arguments and fails on a list;
        # random_sample takes the shape directly and uses the same generator.
        return np.random.random_sample(shape)
|
|
483
523
|
|
|
484
524
|
def bk_reduce_std(self,data,axis=None):
|
|
485
525
|
|
|
@@ -1491,8 +1491,111 @@ class funct(FOC.FoCUS):
|
|
|
1491
1491
|
|
|
1492
1492
|
return scat_cov(mS0, mP00, mC01, mC11, s1=mS1,c10=mC10,backend=self.backend), \
|
|
1493
1493
|
scat_cov(sS0, sP00, sC01, sC11, s1=sS1,c10=sC10,backend=self.backend)
|
|
1494
|
+
|
|
1495
|
+
# compute local direction to make the statistical analysis more efficient
|
|
1496
|
+
def stat_cfft(self, im, image2=None, upscale=False, smooth_scale=0):
    """Estimate local dominant orientations of ``im`` and build orientation-steering matrices.

    For each scale ``k`` (the map is halved in resolution between scales), the
    oriented wavelet responses give a local phase per pixel. That phase is
    quantised into 4 orientation bins with linear interpolation weights
    ``alpha``, and the weights are written into a mixing matrix. The matrices
    can be passed to ``eval(..., cmat=..., cmat2=...)``.

    Assumes ``self.convol`` returns 4 orientations on its last axis
    (indices 0..3 are used) — TODO confirm for other NORIENT values.

    Args:
        im: input map; axis 1 is the HEALPix pixel axis.
        image2: optional second map; if given, the cross response
            ``Re(L1(conv(im) * conj(conv(image2))))`` is used instead of ``|conv(im)|``.
        upscale: if True, first up-grade the map(s) to twice the resolution.
        smooth_scale: number of smooth-and-downgrade passes applied to the
            orientation field before it is brought back to full resolution.

    Returns:
        (cmat, cmat2), two dicts keyed by scale index ``k``:
        ``cmat[k]`` is a complex64 array of shape ``[Npix_k, 16]``, and
        ``cmat2[k]`` has shape ``[k+1, Npix_k, 4, 16]``.
    """
    tmp = im
    if image2 is not None:
        tmpi2 = image2
    if upscale:
        l_nside = int(np.sqrt(tmp.shape[1] // 12))
        tmp = self.up_grade(tmp, l_nside * 2, axis=1)
        if image2 is not None:
            tmpi2 = self.up_grade(tmpi2, l_nside * 2, axis=1)

    l_nside = int(np.sqrt(tmp.shape[1] // 12))
    nscale = int(np.log(l_nside) / np.log(2))
    cmat = {}
    cmat2 = {}
    for k in range(nscale):
        # First-order oriented response at this scale.
        if image2 is not None:
            sim = self.backend.bk_real(self.backend.bk_L1(
                self.convol(tmp, axis=1) * self.backend.bk_conjugate(self.convol(tmpi2, axis=1))))
        else:
            sim = self.backend.bk_abs(self.convol(tmp, axis=1))

        cc = self.backend.bk_reduce_mean(sim[:, :, 0] - sim[:, :, 2], 0)
        ss = self.backend.bk_reduce_mean(sim[:, :, 1] - sim[:, :, 3], 0)
        for m in range(smooth_scale):
            if cc.shape[0] > 12:
                cc = self.ud_grade_2(self.smooth(cc))
                ss = self.ud_grade_2(self.smooth(ss))
        # Compare with the pixel axis of tmp (axis 1). The previous code compared
        # with tmp.shape[0], the batch size, which is the wrong dimension.
        if cc.shape[0] != tmp.shape[1]:
            ll_nside = int(np.sqrt(tmp.shape[1] // 12))
            cc = self.up_grade(cc, ll_nside)
            ss = self.up_grade(ss, ll_nside)

        # Quantise the local phase into 4 bins with linear interpolation weights.
        phase = np.fmod(np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi)
        iph = (4 * phase / (2 * np.pi)).astype('int')
        alpha = (4 * phase / (2 * np.pi) - iph)
        mat = np.zeros([sim.shape[1], 4 * 4])
        lidx = np.arange(sim.shape[1])
        for l in range(4):
            mat[lidx, 4 * ((l + iph) % 4) + l] = 1.0 - alpha
            mat[lidx, 4 * ((l + iph + 1) % 4) + l] = alpha

        cmat[k] = self.backend.bk_cast(mat.astype('complex64'))

        mat2 = np.zeros([k + 1, sim.shape[1], 4, 4 * 4])

        # NOTE(review): nothing in this loop depends on k2, so every mat2[k2]
        # slice is identical — confirm whether a per-k2 input was intended.
        for k2 in range(k + 1):
            # Steer the first-order responses with mat, then convolve again.
            tmp2 = self.backend.bk_repeat(sim, 4, axis=-1)
            sim2 = self.backend.bk_reduce_sum(self.backend.bk_reshape(mat.reshape(1, mat.shape[0], 16) * tmp2,
                                                                    [sim.shape[0], cmat[k].shape[0], 4, 4]), 2)
            sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))

            cc = self.smooth(self.backend.bk_reduce_mean(sim2[:, :, 0] - sim2[:, :, 2], 0))
            ss = self.smooth(self.backend.bk_reduce_mean(sim2[:, :, 1] - sim2[:, :, 3], 0))
            for m in range(smooth_scale):
                if cc.shape[0] > 12:
                    cc = self.ud_grade_2(self.smooth(cc))
                    ss = self.ud_grade_2(self.smooth(ss))
            if cc.shape[0] != sim.shape[1]:
                ll_nside = int(np.sqrt(sim.shape[1] // 12))
                cc = self.up_grade(cc, ll_nside)
                ss = self.up_grade(ss, ll_nside)

            phase = np.fmod(np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi)
            iph = (4 * phase / (2 * np.pi)).astype('int')
            alpha = (4 * phase / (2 * np.pi) - iph)
            lidx = np.arange(sim.shape[1])
            for m in range(4):
                for l in range(4):
                    mat2[k2, lidx, m, 4 * ((l + iph[:, m]) % 4) + l] = 1.0 - alpha[:, m]
                    mat2[k2, lidx, m, 4 * ((l + iph[:, m] + 1) % 4) + l] = alpha[:, m]

        cmat2[k] = self.backend.bk_cast(mat2.astype('complex64'))

        # Halve the resolution for the next scale. The old guard (k < l_nside-1)
        # also downgraded after the last scale, which was wasted work since tmp is
        # no longer used.
        if k < nscale - 1:
            tmp = self.ud_grade_2(tmp, axis=1)
            if image2 is not None:
                tmpi2 = self.ud_grade_2(tmpi2, axis=1)
    return cmat, cmat2
|
|
1494
1597
|
|
|
1495
|
-
def eval(self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False):
|
|
1598
|
+
def eval(self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False,cmat=None,cmat2=None):
|
|
1496
1599
|
"""
|
|
1497
1600
|
Calculates the scattering correlations for a batch of images. Mean are done over pixels.
|
|
1498
1601
|
mean of modulus:
|
|
@@ -1677,6 +1780,11 @@ class funct(FOC.FoCUS):
|
|
|
1677
1780
|
####### S1 and P00
|
|
1678
1781
|
### Make the convolution I1 * Psi_j3
|
|
1679
1782
|
conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
1783
|
+
|
|
1784
|
+
if cmat is not None:
|
|
1785
|
+
tmp2=self.backend.bk_repeat(conv1,4,axis=-1)
|
|
1786
|
+
conv1=self.backend.bk_reduce_sum(self.backend.bk_reshape(cmat[j3]*tmp2,[1,cmat[j3].shape[0],4,4]),2)
|
|
1787
|
+
|
|
1680
1788
|
### Take the module M1 = |I1 * Psi_j3|
|
|
1681
1789
|
M1_square = conv1*self.backend.bk_conjugate(conv1) # [Nbatch, Npix_j3, Norient3]
|
|
1682
1790
|
M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
|
|
@@ -1749,6 +1857,9 @@ class funct(FOC.FoCUS):
|
|
|
1749
1857
|
else: # Cross
|
|
1750
1858
|
### Make the convolution I2 * Psi_j3
|
|
1751
1859
|
conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
1860
|
+
if cmat is not None:
|
|
1861
|
+
tmp2=self.backend.bk_repeat(conv2,4,axis=-1)
|
|
1862
|
+
conv2=self.backend.bk_reduce_sum(self.backend.bk_reshape(cmat[j3]*tmp2,[1,cmat[j3].shape[0],4,4]),2)
|
|
1752
1863
|
### Take the module M2 = |I2 * Psi_j3|
|
|
1753
1864
|
M2_square = conv2*self.backend.bk_conjugate(conv2) # [Nbatch, Npix_j3, Norient3]
|
|
1754
1865
|
M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
|
|
@@ -1852,19 +1963,19 @@ class funct(FOC.FoCUS):
|
|
|
1852
1963
|
### C01_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
1853
1964
|
if not cross:
|
|
1854
1965
|
if calc_var:
|
|
1855
|
-
c01,vc01 = self._compute_C01(j2,
|
|
1966
|
+
c01,vc01 = self._compute_C01(j2,j3,
|
|
1856
1967
|
conv1,
|
|
1857
1968
|
vmask,
|
|
1858
1969
|
M1_dic,
|
|
1859
1970
|
M1convPsi_dic,
|
|
1860
|
-
calc_var=True) # [Nbatch, Nmask, Norient3, Norient2]
|
|
1971
|
+
calc_var=True,cmat2=cmat2) # [Nbatch, Nmask, Norient3, Norient2]
|
|
1861
1972
|
else:
|
|
1862
|
-
c01 = self._compute_C01(j2,
|
|
1973
|
+
c01 = self._compute_C01(j2,j3,
|
|
1863
1974
|
conv1,
|
|
1864
1975
|
vmask,
|
|
1865
1976
|
M1_dic,
|
|
1866
1977
|
M1convPsi_dic,
|
|
1867
|
-
return_data=return_data) # [Nbatch, Nmask, Norient3, Norient2]
|
|
1978
|
+
return_data=return_data,cmat2=cmat2) # [Nbatch, Nmask, Norient3, Norient2]
|
|
1868
1979
|
|
|
1869
1980
|
if return_data:
|
|
1870
1981
|
if C01[j3] is None:
|
|
@@ -1892,31 +2003,31 @@ class funct(FOC.FoCUS):
|
|
|
1892
2003
|
### C10_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
1893
2004
|
else:
|
|
1894
2005
|
if calc_var:
|
|
1895
|
-
c01,vc01 = self._compute_C01(j2,
|
|
2006
|
+
c01,vc01 = self._compute_C01(j2,j3,
|
|
1896
2007
|
conv1,
|
|
1897
2008
|
vmask,
|
|
1898
2009
|
M2_dic,
|
|
1899
2010
|
M2convPsi_dic,
|
|
1900
|
-
calc_var=True)
|
|
1901
|
-
c10,vc10 = self._compute_C01(j2,
|
|
2011
|
+
calc_var=True,cmat2=cmat2)
|
|
2012
|
+
c10,vc10 = self._compute_C01(j2,j3,
|
|
1902
2013
|
conv2,
|
|
1903
2014
|
vmask,
|
|
1904
2015
|
M1_dic,
|
|
1905
2016
|
M1convPsi_dic,
|
|
1906
|
-
calc_var=True)
|
|
2017
|
+
calc_var=True,cmat2=cmat2)
|
|
1907
2018
|
else:
|
|
1908
|
-
c01 = self._compute_C01(j2,
|
|
2019
|
+
c01 = self._compute_C01(j2,j3,
|
|
1909
2020
|
conv1,
|
|
1910
2021
|
vmask,
|
|
1911
2022
|
M2_dic,
|
|
1912
2023
|
M2convPsi_dic,
|
|
1913
|
-
return_data=return_data)
|
|
1914
|
-
c10 = self._compute_C01(j2,
|
|
2024
|
+
return_data=return_data,cmat2=cmat2)
|
|
2025
|
+
c10 = self._compute_C01(j2,j3,
|
|
1915
2026
|
conv2,
|
|
1916
2027
|
vmask,
|
|
1917
2028
|
M1_dic,
|
|
1918
2029
|
M1convPsi_dic,
|
|
1919
|
-
return_data=return_data)
|
|
2030
|
+
return_data=return_data,cmat2=cmat2)
|
|
1920
2031
|
|
|
1921
2032
|
if return_data:
|
|
1922
2033
|
if C01[j3] is None:
|
|
@@ -2077,11 +2188,12 @@ class funct(FOC.FoCUS):
|
|
|
2077
2188
|
self.P2_dic = None
|
|
2078
2189
|
return
|
|
2079
2190
|
|
|
2080
|
-
def _compute_C01(self, j2, conv,
|
|
2191
|
+
def _compute_C01(self, j2, j3,conv,
|
|
2081
2192
|
vmask, M_dic,
|
|
2082
2193
|
MconvPsi_dic,
|
|
2083
2194
|
calc_var=False,
|
|
2084
|
-
return_data=False
|
|
2195
|
+
return_data=False,
|
|
2196
|
+
cmat2=None):
|
|
2085
2197
|
"""
|
|
2086
2198
|
Compute the C01 coefficients (auto or cross)
|
|
2087
2199
|
C01 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
|
|
@@ -2094,7 +2206,10 @@ class funct(FOC.FoCUS):
|
|
|
2094
2206
|
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
2095
2207
|
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
|
|
2096
2208
|
MconvPsi = self.convol(M_dic[j2], axis=1) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
2097
|
-
|
|
2209
|
+
if cmat2 is not None:
|
|
2210
|
+
tmp2=self.backend.bk_repeat(MconvPsi,4,axis=-1)
|
|
2211
|
+
MconvPsi=self.backend.bk_reduce_sum(self.backend.bk_reshape(cmat2[j3][j2]*tmp2,[1,cmat2[j3].shape[1],4,4,4]),3)
|
|
2212
|
+
|
|
2098
2213
|
# Store it so we can use it in C11 computation
|
|
2099
2214
|
MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
2100
2215
|
|
|
@@ -2306,13 +2421,13 @@ class funct(FOC.FoCUS):
|
|
|
2306
2421
|
"""
|
|
2307
2422
|
@tf.function
|
|
2308
2423
|
"""
|
|
2309
|
-
def eval_comp_fast(self, image1, image2=None,mask=None,norm=None, Auto=True):
|
|
2424
|
+
def eval_comp_fast(self, image1, image2=None,mask=None,norm=None, Auto=True,cmat=None,cmat2=None):
|
|
2310
2425
|
|
|
2311
|
-
res=self.eval(image1, image2=image2,mask=mask,Auto=Auto)
|
|
2426
|
+
res=self.eval(image1, image2=image2,mask=mask,Auto=Auto,cmat=cmat,cmat2=cmat2)
|
|
2312
2427
|
return res.S0,res.P00,res.S1,res.C01,res.C11,res.C10
|
|
2313
2428
|
|
|
2314
|
-
def eval_fast(self, image1, image2=None,mask=None,norm=None, Auto=True):
|
|
2315
|
-
s0,p0,s1,c01,c11,c10=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto)
|
|
2429
|
+
def eval_fast(self, image1, image2=None,mask=None,norm=None, Auto=True,cmat=None,cmat2=None):
|
|
2430
|
+
s0,p0,s1,c01,c11,c10=self.eval_comp_fast(image1, image2=image2,mask=mask,Auto=Auto,cmat=cmat,cmat2=cmat2)
|
|
2316
2431
|
return scat_cov(s0, p0, c01, c11, s1=s1,c10=c10,backend=self.backend)
|
|
2317
2432
|
|
|
2318
2433
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|