foscat 2025.5.2__py3-none-any.whl → 2025.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/BkTorch.py CHANGED
@@ -149,11 +149,7 @@ class BkTorch(BackendBase.BackendBase):
149
149
  # -- BACKEND DEFINITION --
150
150
  # ---------------------------------------------−---------
151
151
  def bk_SparseTensor(self, indice, w, dense_shape=[]):
152
- return (
153
- self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
154
- .to_sparse_csr()
155
- .to(self.torch_device)
156
- )
152
+ return self.backend.sparse_coo_tensor(indice.T, w, dense_shape).to_sparse_csr().to(self.torch_device)
157
153
 
158
154
  def bk_stack(self, list, axis=0):
159
155
  return self.backend.stack(list, axis=axis).to(self.torch_device)
@@ -246,13 +242,13 @@ class BkTorch(BackendBase.BackendBase):
246
242
  xr = self.bk_real(x)
247
243
  # xi = self.bk_imag(x)
248
244
 
249
- r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
245
+ r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr + 1E-16)
250
246
  # return r
251
247
  # i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
252
248
 
253
249
  return r
254
250
  else:
255
- return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x)
251
+ return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x + 1E-16)
256
252
 
257
253
  def bk_square_comp(self, x):
258
254
  if x.dtype == self.all_cbk_type:
@@ -391,9 +387,9 @@ class BkTorch(BackendBase.BackendBase):
391
387
  return self.backend.argmax(data)
392
388
 
393
389
  def bk_reshape(self, data, shape):
394
- if isinstance(data, np.ndarray):
395
- return data.reshape(shape)
396
- return data.view(shape)
390
+ #if isinstance(data, np.ndarray):
391
+ # return data.reshape(shape)
392
+ return data.reshape(shape)
397
393
 
398
394
  def bk_repeat(self, data, nn, axis=0):
399
395
  return self.backend.repeat_interleave(data, repeats=nn, dim=axis)
@@ -440,7 +436,9 @@ class BkTorch(BackendBase.BackendBase):
440
436
  return self.backend.zeros(shape, dtype=dtype).to(self.torch_device)
441
437
 
442
438
  def bk_gather(self, data, idx, axis=0):
443
- if axis == 0:
439
+ if axis == -1:
440
+ return data[...,idx]
441
+ elif axis == 0:
444
442
  return data[idx]
445
443
  elif axis == 1:
446
444
  return data[:, idx]
@@ -448,7 +446,7 @@ class BkTorch(BackendBase.BackendBase):
448
446
  return data[:, :, idx]
449
447
  elif axis == 3:
450
448
  return data[:, :, :, idx]
451
- return data[:, :, :, :, idx]
449
+ return data[idx,...]
452
450
 
453
451
  def bk_reverse(self, data, axis=0):
454
452
  return self.backend.flip(data, dims=[axis])
foscat/CNN.py CHANGED
@@ -9,13 +9,12 @@ class CNN:
9
9
 
10
10
  def __init__(
11
11
  self,
12
- scat_operator=None,
13
12
  nparam=1,
14
- nscale=1,
13
+ KERNELSZ=3,
14
+ NORIENT=4,
15
15
  chanlist=[],
16
16
  in_nside=1,
17
17
  n_chan_in=1,
18
- nbatch=1,
19
18
  SEED=1234,
20
19
  filename=None,
21
20
  ):
@@ -31,31 +30,30 @@ class CNN:
31
30
  self.in_nside = outlist[4]
32
31
  self.nbatch = outlist[1]
33
32
  self.n_chan_in = outlist[8]
33
+ self.NORIENT = outlist[9]
34
34
  self.x = self.scat_operator.backend.bk_cast(outlist[6])
35
35
  self.out_nside = self.in_nside // (2**self.nscale)
36
36
  else:
37
- self.nscale = nscale
38
- self.nbatch = nbatch
37
+ self.nscale = len(chanlist)-1
39
38
  self.npar = nparam
40
39
  self.n_chan_in = n_chan_in
41
40
  self.scat_operator = scat_operator
42
- if len(chanlist) != nscale + 1:
43
- print(
44
- "len of chanlist (here %d) should of nscale+1 (here %d)"
45
- % (len(chanlist), nscale + 1)
46
- )
47
- return None
41
+ if self.scat_operator is None:
42
+ self.scat_operator = sc.funct(
43
+ KERNELSZ=KERNELSZ,
44
+ NORIENT=NORIENT)
48
45
 
49
46
  self.chanlist = chanlist
50
- self.KERNELSZ = scat_operator.KERNELSZ
51
- self.all_type = scat_operator.all_type
47
+ self.KERNELSZ = self.scat_operator.KERNELSZ
48
+ self.NORIENT = self.scat_operator.NORIENT
49
+ self.all_type = self.scat_operator.all_type
52
50
  self.in_nside = in_nside
53
51
  self.out_nside = self.in_nside // (2**self.nscale)
54
-
52
+ self.backend = self.scat_operator.backend
55
53
  np.random.seed(SEED)
56
- self.x = scat_operator.backend.bk_cast(
57
- np.random.randn(self.get_number_of_weights())
58
- / (self.KERNELSZ * self.KERNELSZ)
54
+ self.x = self.scat_operator.backend.bk_cast(
55
+ np.random.rand(self.get_number_of_weights())
56
+ / (self.KERNELSZ * (self.KERNELSZ//2+1)*self.NORIENT)
59
57
  )
60
58
 
61
59
  def save(self, filename):
@@ -70,6 +68,7 @@ class CNN:
70
68
  self.get_weights().numpy(),
71
69
  self.all_type,
72
70
  self.n_chan_in,
71
+ self.NORIENT,
73
72
  ]
74
73
 
75
74
  myout = open("%s.pkl" % (filename), "wb")
@@ -82,8 +81,8 @@ class CNN:
82
81
  totnchan = totnchan + self.chanlist[i] * self.chanlist[i + 1]
83
82
  return (
84
83
  self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]
85
- + totnchan * self.KERNELSZ * self.KERNELSZ
86
- + self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
84
+ + totnchan * self.KERNELSZ * (self.KERNELSZ//2+1)
85
+ + self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]
87
86
  )
88
87
 
89
88
  def set_weights(self, x):
@@ -95,30 +94,32 @@ class CNN:
95
94
  def eval(self, im, indices=None, weights=None):
96
95
 
97
96
  x = self.x
98
- ww = self.scat_operator.backend.bk_reshape(
99
- x[0 : self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]],
100
- [self.KERNELSZ * self.KERNELSZ, self.n_chan_in, self.chanlist[0]],
97
+ ww = self.backend.bk_reshape(
98
+ x[0 : self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]],
99
+ [self.n_chan_in, self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[0]],
101
100
  )
102
- nn = self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
101
+ nn = self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]
103
102
 
104
103
  im = self.scat_operator.healpix_layer(im, ww)
105
- im = self.scat_operator.backend.bk_relu(im)
104
+ im = self.backend.bk_relu(im)
105
+
106
+ im = self.backend.bk_reduce_mean(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],im.shape[2]//4,4]),3)
106
107
 
107
108
  for k in range(self.nscale):
108
109
  ww = self.scat_operator.backend.bk_reshape(
109
110
  x[
110
111
  nn : nn
111
112
  + self.KERNELSZ
112
- * self.KERNELSZ
113
+ * (self.KERNELSZ//2+1)
113
114
  * self.chanlist[k]
114
115
  * self.chanlist[k + 1]
115
116
  ],
116
- [self.KERNELSZ * self.KERNELSZ, self.chanlist[k], self.chanlist[k + 1]],
117
+ [self.chanlist[k], self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[k + 1]],
117
118
  )
118
119
  nn = (
119
120
  nn
120
121
  + self.KERNELSZ
121
- * self.KERNELSZ
122
+ * (self.KERNELSZ//2)
122
123
  * self.chanlist[k]
123
124
  * self.chanlist[k + 1]
124
125
  )
@@ -129,7 +130,7 @@ class CNN:
129
130
  im, ww, indices=indices[k], weights=weights[k]
130
131
  )
131
132
  im = self.scat_operator.backend.bk_relu(im)
132
- im = self.scat_operator.ud_grade_2(im, axis=0)
133
+ im = self.backend.bk_reduce_mean(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],im.shape[2]//4,4]),3)
133
134
 
134
135
  ww = self.scat_operator.backend.bk_reshape(
135
136
  x[
@@ -141,11 +142,11 @@ class CNN:
141
142
 
142
143
  im = self.scat_operator.backend.bk_matmul(
143
144
  self.scat_operator.backend.bk_reshape(
144
- im, [1, 12 * self.out_nside**2 * self.chanlist[self.nscale]]
145
+ im, [im.shape[0], im.shape[1] * im.shape[2]]
145
146
  ),
146
147
  ww,
147
148
  )
148
- im = self.scat_operator.backend.bk_reshape(im, [self.npar])
149
+ #im = self.scat_operator.backend.bk_reshape(im, [self.npar])
149
150
  im = self.scat_operator.backend.bk_relu(im)
150
151
 
151
152
  return im
foscat/FoCUS.py CHANGED
@@ -35,7 +35,7 @@ class FoCUS:
35
35
  mpi_rank=0,
36
36
  ):
37
37
 
38
- self.__version__ = "2025.05.2"
38
+ self.__version__ = "2025.06.1"
39
39
  # P00 coeff for normalization for scat_cov
40
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
41
  self.P1_dic = None
@@ -176,6 +176,9 @@ class FoCUS:
176
176
  self.Y_CNN = {}
177
177
  self.Z_CNN = {}
178
178
 
179
+ self.Idx_CNN = {}
180
+ self.Idx_WCNN = {}
181
+
179
182
  self.filters_set = {}
180
183
  self.edge_masks = {}
181
184
 
@@ -500,210 +503,26 @@ class FoCUS:
500
503
  return indices, weights, xc, yc, zc
501
504
 
502
505
  # ---------------------------------------------−---------
503
- def calc_orientation(self, im): # im is [Ndata,12*Nside**2]
504
- nside = int(np.sqrt(im.shape[1] // 12))
505
- l_kernel = self.KERNELSZ * self.KERNELSZ
506
- norient = 32
507
- w = np.zeros([l_kernel, 1, 2 * norient])
508
- ca = np.cos(np.arange(norient) / norient * np.pi)
509
- sa = np.sin(np.arange(norient) / norient * np.pi)
510
- stat = np.zeros([12 * nside**2, norient])
511
-
512
- if self.ww_CNN[nside] is None:
513
- self.init_CNN_index(nside, transpose=False)
514
-
515
- y = self.Y_CNN[nside]
516
- z = self.Z_CNN[nside]
517
-
518
- for k in range(norient):
519
- w[:, 0, k] = np.exp(-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)) * np.cos(
520
- nside * (y * ca[k] + z * sa[k]) * np.pi / 2
521
- )
522
- w[:, 0, k + norient] = np.exp(
523
- -0.5 * nside**2 * ((y) ** 2 + (z) ** 2)
524
- ) * np.sin(nside * (y * ca[k] + z * sa[k]) * np.pi / 2)
525
- w[:, 0, k] = w[:, 0, k] - np.mean(w[:, 0, k])
526
- w[:, 0, k + norient] = w[:, 0, k] - np.mean(w[:, 0, k + norient])
527
-
528
- for k in range(im.shape[0]):
529
- tmp = im[k].reshape(12 * nside**2, 1)
530
- im2 = self.healpix_layer(tmp, w)
531
- stat = stat + im2[:, 0:norient] ** 2 + im2[:, norient:] ** 2
532
-
533
- rotation = (np.argmax(stat, 1)).astype("float") / 32.0 * 180.0
534
-
535
- indices, weights, x, y, z = self.calc_indices_convol(
536
- nside, 9, rotation=rotation
537
- )
538
-
539
- return indices, weights
540
-
541
- def init_CNN_index(self, nside, transpose=False):
542
- l_kernel = int(self.KERNELSZ * self.KERNELSZ)
543
- try:
544
- indices = np.load(
545
- "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
546
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
547
- )
548
- weights = np.load(
549
- "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
550
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
551
- )
552
- xc = np.load(
553
- "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
554
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
555
- )
556
- yc = np.load(
557
- "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
558
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
559
- )
560
- zc = np.load(
561
- "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
562
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
563
- )
564
- except:
565
- indices, weights, xc, yc, zc = self.calc_indices_convol(nside, l_kernel)
566
- np.save(
567
- "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
568
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
569
- indices,
570
- )
571
- np.save(
572
- "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
573
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
574
- weights,
575
- )
576
- np.save(
577
- "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
578
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
579
- xc,
580
- )
581
- np.save(
582
- "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
583
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
584
- yc,
585
- )
586
- np.save(
587
- "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
588
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
589
- zc,
590
- )
591
- if not self.silent:
592
- print(
593
- "Write %s/FOSCAT_%s_W%d_%d_%d_CNNV2.npy"
594
- % (
595
- self.TEMPLATE_PATH,
596
- TMPFILE_VERSION,
597
- l_kernel,
598
- self.NORIENT,
599
- nside,
600
- )
601
- )
602
-
603
- self.X_CNN[nside] = xc
604
- self.Y_CNN[nside] = yc
605
- self.Z_CNN[nside] = zc
606
- self.ww_CNN[nside] = self.backend.bk_SparseTensor(
607
- indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
608
- )
609
-
610
- # ---------------------------------------------−---------
611
- def healpix_layer_coord(self, im, axis=0):
612
- nside = int(np.sqrt(im.shape[axis] // 12))
613
- if self.ww_CNN[nside] is None:
614
- self.init_CNN_index(nside)
615
- return self.X_CNN[nside], self.Y_CNN[nside], self.Z_CNN[nside]
616
-
617
- # ---------------------------------------------−---------
618
- def healpix_layer_transpose(self, im, ww, indices=None, weights=None, axis=0):
619
- nside = int(np.sqrt(im.shape[axis] // 12))
620
-
621
- if im.shape[1 + axis] != ww.shape[1]:
622
- if not self.silent:
623
- print("Weights channels should be equal to the input image channels")
624
- return -1
625
- if axis == 1:
626
- results = []
627
-
628
- for k in range(im.shape[0]):
629
-
630
- tmp = self.healpix_layer(
631
- im[k], ww, indices=indices, weights=weights, axis=0
632
- )
633
- tmp = self.backend.bk_reshape(
634
- self.up_grade(tmp, 2 * nside), [12 * 4 * nside**2, ww.shape[2]]
635
- )
636
-
637
- results.append(tmp)
638
-
639
- return self.backend.bk_stack(results, axis=0)
640
- else:
641
- tmp = self.healpix_layer(
642
- im, ww, indices=indices, weights=weights, axis=axis
643
- )
644
-
645
- return self.up_grade(tmp, 2 * nside)
646
-
647
506
  # ---------------------------------------------−---------
648
- # ---------------------------------------------−---------
649
- def healpix_layer(self, im, ww, indices=None, weights=None, axis=0):
650
- nside = int(np.sqrt(im.shape[axis] // 12))
651
- l_kernel = self.KERNELSZ * self.KERNELSZ
652
-
653
- if im.shape[1 + axis] != ww.shape[1]:
654
- if not self.silent:
655
- print("Weights channels should be equal to the input image channels")
656
- return -1
657
-
507
+ def healpix_layer(self, im, ww, indices=None, weights=None):
508
+ #ww [N_i,NORIENT,KERNELSZ*KERNELSZ//2,N_o,NORIENT]
509
+ #im [N_batch,N_i, NORIENT,N]
510
+ nside=int(np.sqrt(im.shape[-1]//12))
658
511
  if indices is None:
659
- if self.ww_CNN[nside] is None:
660
- self.init_CNN_index(nside, transpose=False)
661
- mat = self.ww_CNN[nside]
662
- else:
663
- if weights is None:
664
- print(
665
- "healpix_layer : If indices is not none weights should be specify"
666
- )
667
- return 0
668
-
669
- mat = self.backend.bk_SparseTensor(
670
- indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
671
- )
672
-
673
- if axis == 1:
674
- results = []
675
-
676
- for k in range(im.shape[0]):
677
-
678
- tmp = self.backend.bk_sparse_dense_matmul(mat, im[k])
679
-
680
- density = self.backend.bk_reshape(
681
- tmp, [12 * nside * nside, l_kernel * im.shape[1 + axis]]
682
- )
683
-
684
- density = self.backend.bk_matmul(
685
- density,
686
- self.backend.bk_reshape(
687
- ww, [l_kernel * im.shape[1 + axis], ww.shape[2]]
688
- ),
689
- )
690
-
691
- results.append(
692
- self.backend.bk_reshape(density, [12 * nside**2, ww.shape[2]])
693
- )
694
-
695
- return self.backend.bk_stack(results, axis=0)
696
- else:
697
- tmp = self.backend.bk_sparse_dense_matmul(mat, im)
698
-
699
- density = self.backend.bk_reshape(
700
- tmp, [12 * nside * nside, l_kernel * im.shape[1]]
701
- )
702
-
703
- return self.backend.bk_matmul(
704
- density,
705
- self.backend.bk_reshape(ww, [l_kernel * im.shape[1], ww.shape[2]]),
706
- )
512
+ if (nside,self.NORIENT,self.KERNELSZ) not in self.ww_CNN:
513
+ self.init_index_cnn(nside,self.NORIENT)
514
+ indices = self.Idx_CNN[(nside,self.NORIENT,self.KERNELSZ)]
515
+ mat = self.Idx_WCNN[(nside,self.NORIENT,self.KERNELSZ)]
516
+
517
+ wim = self.backend.bk_gather(im,indices.flatten(),axis=3) #[N_batch,N_i,NORIENT,K*(K+1),N_o,NORIENT,N,N_w]
518
+
519
+ wim = self.backend.bk_reshape(wim,[im.shape[0],im.shape[1],im.shape[2]]+list(indices.shape))*mat[None,...]
520
+ #win is [N_batch,N_i, NORIENT,K*(K+1),1, NORIENT,N,N_w]
521
+ #ww is [1, N_i, NORIENT,K*(K+1),N_o,NORIENT]
522
+ wim = self.backend.bk_reduce_sum(wim[:,:,:,:,None]*ww[None,:,:,:,:,:,None,None],[1,2,3])
523
+
524
+ wim = self.backend.bk_reduce_sum(wim,-1)
525
+ return self.backend.bk_reshape(wim,[im.shape[0],ww.shape[3],ww.shape[4],im.shape[-1]])
707
526
 
708
527
  # ---------------------------------------------−---------
709
528
 
@@ -1775,6 +1594,232 @@ class FoCUS:
1775
1594
 
1776
1595
  return wr, wi, ws, tmp
1777
1596
 
1597
+
1598
+ # ---------------------------------------------−---------
1599
+ def init_index_cnn(self, nside, NORIENT=4,kernel=-1, cell_ids=None):
1600
+
1601
+ if kernel == -1:
1602
+ l_kernel = self.KERNELSZ
1603
+ else:
1604
+ l_kernel = kernel
1605
+
1606
+ if cell_ids is not None:
1607
+ ncell = cell_ids.shape[0]
1608
+ else:
1609
+ ncell = 12 * nside * nside
1610
+
1611
+ try:
1612
+
1613
+ if cell_ids is not None:
1614
+ tmp = np.load(
1615
+ "%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
1616
+ % (
1617
+ self.TEMPLATE_PATH,
1618
+ TMPFILE_VERSION,
1619
+ l_kernel**2,
1620
+ NORIENT,
1621
+ nside, # if cell_ids computes the index
1622
+ )
1623
+ )
1624
+
1625
+ else:
1626
+ tmp = np.load(
1627
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1628
+ % (
1629
+ self.TEMPLATE_PATH,
1630
+ TMPFILE_VERSION,
1631
+ l_kernel**2,
1632
+ NORIENT,
1633
+ nside, # if cell_ids computes the index
1634
+ )
1635
+ )
1636
+ except:
1637
+
1638
+ pw = 8.0
1639
+ pw2 = 1.0
1640
+ threshold = 1e-3
1641
+
1642
+ if l_kernel == 5:
1643
+ pw = 8.0
1644
+ pw2 = 0.5
1645
+ threshold = 2e-4
1646
+
1647
+ elif l_kernel == 3:
1648
+ pw = 8.0
1649
+ pw2 = 1.0
1650
+ threshold = 1e-3
1651
+
1652
+ elif l_kernel == 7:
1653
+ pw = 8.0
1654
+ pw2 = 0.25
1655
+ threshold = 4e-5
1656
+
1657
+ n_weights = self.KERNELSZ*(self.KERNELSZ//2+1)
1658
+
1659
+ if cell_ids is not None:
1660
+ if not isinstance(cell_ids, np.ndarray):
1661
+ cell_ids = self.backend.to_numpy(cell_ids)
1662
+ th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1663
+ x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1664
+
1665
+ t, p = hp.pix2ang(nside, cell_ids, nest=True)
1666
+ phi = [p[k] / np.pi * 180 for k in range(ncell)]
1667
+ thi = [t[k] / np.pi * 180 for k in range(ncell)]
1668
+
1669
+ indice = np.zeros([n_weights, NORIENT, ncell,4], dtype="int")
1670
+
1671
+ wav = np.zeros([n_weights, NORIENT, ncell,4], dtype="float")
1672
+
1673
+ else:
1674
+
1675
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1676
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1677
+
1678
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1679
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1680
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1681
+
1682
+ indice = np.zeros(
1683
+ [n_weights, NORIENT, 12 * nside * nside,4], dtype="int"
1684
+ )
1685
+ wav = np.zeros(
1686
+ [n_weights, NORIENT, 12 * nside * nside,4], dtype="float"
1687
+ )
1688
+ iv = 0
1689
+ iv2 = 0
1690
+
1691
+ for iii in range(ncell):
1692
+ if cell_ids is None:
1693
+ if iii % (nside * nside) == nside * nside - 1:
1694
+ if not self.silent:
1695
+ print(
1696
+ "Pre-compute nside=%6d %.2f%%"
1697
+ % (nside, 100 * iii / (12 * nside * nside))
1698
+ )
1699
+
1700
+ if cell_ids is not None:
1701
+ hidx = np.where(
1702
+ (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1703
+ < (2 * np.pi / nside) ** 2
1704
+ )[0]
1705
+ else:
1706
+ hidx = hp.query_disc(
1707
+ nside,
1708
+ [x[iii], y[iii], z[iii]],
1709
+ 2 * np.pi / nside,
1710
+ nest=True,
1711
+ )
1712
+
1713
+ R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1714
+
1715
+ t2, p2 = R(th[hidx], ph[hidx])
1716
+
1717
+ vec2 = hp.ang2vec(t2, p2)
1718
+
1719
+ x2 = vec2[:, 0]
1720
+ y2 = vec2[:, 1]
1721
+ z2 = vec2[:, 2]
1722
+
1723
+ for l_rotation in range(NORIENT):
1724
+
1725
+ angle = (
1726
+ l_rotation / 4.0 * np.pi
1727
+ - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1728
+ - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1729
+ )
1730
+
1731
+
1732
+ axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1733
+ axes2 = -y2 * np.sin(angle) - x2 * np.cos(angle)
1734
+
1735
+ for k_weights in range(self.KERNELSZ//2+1):
1736
+ for l_weights in range(self.KERNELSZ):
1737
+
1738
+ val=np.exp(-(pw*(axes2*(nside)-(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))+ \
1739
+ np.exp(-(pw*(axes2*(nside)+(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))
1740
+
1741
+ idx = np.argsort(-val)
1742
+ idx = idx[0:4]
1743
+
1744
+ nval = len(idx)
1745
+ val=val[idx]
1746
+
1747
+ r = abs(val).sum()
1748
+
1749
+ if r > 0:
1750
+ val = val / r
1751
+
1752
+ indice[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = hidx[idx]
1753
+ wav[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = val
1754
+
1755
+ if not self.silent:
1756
+ print("Kernel Size ", iv / (NORIENT * 12 * nside * nside))
1757
+
1758
+ if cell_ids is None:
1759
+ if not self.silent:
1760
+ print(
1761
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1762
+ % (TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside)
1763
+ )
1764
+ np.save(
1765
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1766
+ % (
1767
+ self.TEMPLATE_PATH,
1768
+ TMPFILE_VERSION,
1769
+ self.KERNELSZ**2,
1770
+ NORIENT,
1771
+ nside,
1772
+ ),
1773
+ indice,
1774
+ )
1775
+ np.save(
1776
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1777
+ % (
1778
+ self.TEMPLATE_PATH,
1779
+ TMPFILE_VERSION,
1780
+ self.KERNELSZ**2,
1781
+ NORIENT,
1782
+ nside,
1783
+ ),
1784
+ wav,
1785
+ )
1786
+
1787
+ if cell_ids is None:
1788
+ self.barrier()
1789
+ if self.use_2D:
1790
+ tmp = np.load(
1791
+ "%s/W%d_%s_%d_IDX.npy"
1792
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1793
+ )
1794
+ else:
1795
+ tmp = np.load(
1796
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1797
+ % (
1798
+ self.TEMPLATE_PATH,
1799
+ TMPFILE_VERSION,
1800
+ self.KERNELSZ**2,
1801
+ NORIENT,
1802
+ nside,
1803
+ )
1804
+ )
1805
+ wav = np.load(
1806
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1807
+ % (
1808
+ self.TEMPLATE_PATH,
1809
+ TMPFILE_VERSION,
1810
+ self.KERNELSZ**2,
1811
+ NORIENT,
1812
+ nside,
1813
+ )
1814
+ )
1815
+ else:
1816
+ tmp = indice
1817
+
1818
+ self.Idx_CNN[(nside,NORIENT,self.KERNELSZ)] = tmp
1819
+ self.Idx_WCNN[(nside,NORIENT,self.KERNELSZ)] = self.backend.bk_cast(wav)
1820
+
1821
+ return wav, tmp
1822
+
1778
1823
  # ---------------------------------------------−---------
1779
1824
  # convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
1780
1825
  def swapaxes(self, x, axis1, axis2):