foscat 2025.8.3__py3-none-any.whl → 2025.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +241 -49
- foscat/FoCUS.py +5 -3
- foscat/HOrientedConvol.py +446 -42
- foscat/HealBili.py +305 -0
- foscat/Plot.py +328 -0
- foscat/UNET.py +455 -178
- foscat/healpix_unet_torch.py +717 -0
- foscat/scat_cov.py +42 -30
- {foscat-2025.8.3.dist-info → foscat-2025.9.1.dist-info}/METADATA +1 -1
- {foscat-2025.8.3.dist-info → foscat-2025.9.1.dist-info}/RECORD +13 -10
- {foscat-2025.8.3.dist-info → foscat-2025.9.1.dist-info}/WHEEL +0 -0
- {foscat-2025.8.3.dist-info → foscat-2025.9.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.8.3.dist-info → foscat-2025.9.1.dist-info}/top_level.txt +0 -0
foscat/HOrientedConvol.py
CHANGED
|
@@ -3,16 +3,36 @@ import matplotlib.pyplot as plt
|
|
|
3
3
|
import healpy as hp
|
|
4
4
|
from scipy.sparse import csr_array
|
|
5
5
|
import torch
|
|
6
|
+
import foscat.scat_cov as sc
|
|
6
7
|
from scipy.spatial import cKDTree
|
|
7
8
|
|
|
8
9
|
class HOrientedConvol:
|
|
9
|
-
def __init__(self,
|
|
10
|
+
def __init__(self,
|
|
11
|
+
nside,
|
|
12
|
+
KERNELSZ,
|
|
13
|
+
cell_ids=None,
|
|
14
|
+
nest=True,
|
|
15
|
+
device='cuda',
|
|
16
|
+
dtype='float64',
|
|
17
|
+
polar=False,
|
|
18
|
+
gamma=1.0,
|
|
19
|
+
allow_extrapolation=True,
|
|
20
|
+
no_cell_ids=False,
|
|
21
|
+
):
|
|
22
|
+
|
|
23
|
+
if dtype=='float64':
|
|
24
|
+
self.dtype=torch.float64
|
|
25
|
+
else:
|
|
26
|
+
self.dtype=torch.float32
|
|
10
27
|
|
|
11
28
|
if KERNELSZ % 2 == 0:
|
|
12
29
|
raise ValueError(f"N must be odd so that coordinates are integers from -K..K; got N={KERNELSZ}.")
|
|
13
30
|
|
|
14
31
|
self.local_test=False
|
|
15
|
-
|
|
32
|
+
|
|
33
|
+
if no_cell_ids==True:
|
|
34
|
+
cell_ids=np.arange(10)
|
|
35
|
+
|
|
16
36
|
if cell_ids is None:
|
|
17
37
|
self.cell_ids=np.arange(12*nside**2)
|
|
18
38
|
|
|
@@ -28,37 +48,84 @@ class HOrientedConvol:
|
|
|
28
48
|
self.cell_ids=cell_ids
|
|
29
49
|
|
|
30
50
|
self.local_test=True
|
|
31
|
-
|
|
32
|
-
idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
|
|
33
|
-
KERNELSZ*KERNELSZ,
|
|
34
|
-
nside,
|
|
35
|
-
nest=nest,
|
|
36
|
-
)
|
|
37
51
|
|
|
52
|
+
if self.cell_ids.ndim==1:
|
|
53
|
+
idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
|
|
54
|
+
KERNELSZ*KERNELSZ,
|
|
55
|
+
nside,
|
|
56
|
+
nest=nest,
|
|
57
|
+
)
|
|
58
|
+
else:
|
|
59
|
+
idx_nn = []
|
|
60
|
+
for k in range(self.cell_ids.shape[0]):
|
|
61
|
+
idx_nn.append(self.knn_healpix_ckdtree(self.cell_ids[k],
|
|
62
|
+
KERNELSZ*KERNELSZ,
|
|
63
|
+
nside,
|
|
64
|
+
nest=nest,
|
|
65
|
+
))
|
|
66
|
+
idx_nn=np.stack(idx_nn,0)
|
|
67
|
+
|
|
68
|
+
if self.cell_ids.ndim==1:
|
|
69
|
+
mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids,nest=nest)
|
|
38
70
|
|
|
39
|
-
|
|
71
|
+
if self.local_test:
|
|
72
|
+
t,p = hp.pix2ang(nside,self.cell_ids[idx_nn],nest=True)
|
|
73
|
+
else:
|
|
74
|
+
t,p = hp.pix2ang(nside,idx_nn,nest=True)
|
|
75
|
+
|
|
76
|
+
self.t=t[:,0]
|
|
77
|
+
self.p=p[:,0]
|
|
78
|
+
vec_orig=hp.ang2vec(t,p)
|
|
79
|
+
|
|
80
|
+
self.vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
|
|
40
81
|
|
|
41
|
-
|
|
42
|
-
|
|
82
|
+
'''
|
|
83
|
+
if self.local_test:
|
|
84
|
+
idx_nn=self.remap_by_first_column(idx_nn)
|
|
85
|
+
'''
|
|
86
|
+
|
|
87
|
+
del mat_pt
|
|
88
|
+
del vec_orig
|
|
43
89
|
else:
|
|
44
|
-
t,p = hp.pix2ang(nside,idx_nn,nest=True)
|
|
45
90
|
|
|
46
|
-
|
|
91
|
+
t,p,vec_rot = [],[],[]
|
|
92
|
+
|
|
93
|
+
for k in range(self.cell_ids.shape[0]):
|
|
94
|
+
mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids[k],nest=nest)
|
|
95
|
+
|
|
96
|
+
lt,lp = hp.pix2ang(nside,self.cell_ids[k,idx_nn[k]],nest=True)
|
|
97
|
+
|
|
98
|
+
vec_orig=hp.ang2vec(lt,lp)
|
|
99
|
+
|
|
100
|
+
l_vec_rot=np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
|
|
101
|
+
vec_rot.append(l_vec_rot)
|
|
102
|
+
|
|
103
|
+
del vec_orig
|
|
104
|
+
del mat_pt
|
|
105
|
+
|
|
106
|
+
t.append(lt[:,0])
|
|
107
|
+
p.append(lp[:,0])
|
|
47
108
|
|
|
48
|
-
|
|
109
|
+
|
|
110
|
+
self.t=np.stack(t,0)
|
|
111
|
+
self.p=np.stack(p,0)
|
|
112
|
+
self.vec_rot=np.stack(vec_rot,0)
|
|
49
113
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
114
|
+
del t
|
|
115
|
+
del p
|
|
116
|
+
del vec_rot
|
|
117
|
+
|
|
118
|
+
self.polar=polar
|
|
119
|
+
self.gamma=gamma
|
|
120
|
+
self.device=device
|
|
121
|
+
self.allow_extrapolation=allow_extrapolation
|
|
122
|
+
self.w_idx=None
|
|
54
123
|
|
|
55
|
-
del mat_pt
|
|
56
|
-
del vec_orig
|
|
57
|
-
self.t=t[:,0]
|
|
58
|
-
self.p=p[:,0]
|
|
59
124
|
self.idx_nn=idx_nn
|
|
60
125
|
self.nside=nside
|
|
61
126
|
self.KERNELSZ=KERNELSZ
|
|
127
|
+
self.nest=nest
|
|
128
|
+
self.f=None
|
|
62
129
|
|
|
63
130
|
def remap_by_first_column(self,idx: np.ndarray) -> np.ndarray:
|
|
64
131
|
"""
|
|
@@ -290,25 +357,95 @@ class HOrientedConvol:
|
|
|
290
357
|
|
|
291
358
|
return csr_array((w, (indice_1_0, indice_1_1)), shape=(12*self.nside**2, 12*self.nside**2*NORIENT))
|
|
292
359
|
|
|
293
|
-
|
|
294
|
-
|
|
360
|
+
def make_idx_weights_from_cell_ids(self, i_cell_ids,
                                   polar=False,
                                   gamma=1.0,
                                   device='cuda',
                                   allow_extrapolation=True):
    """
    Compute neighbour indices and interpolation weights for one or several id sets.

    Parameters
    ----------
    i_cell_ids : array-like, shape (Npix,) or (B, Npix)
        One set of cell ids, or one set per batch sample.
    polar, gamma, allow_extrapolation :
        Forwarded unchanged to ``make_idx_weights_from_one_cell_ids``.
    device : str or torch.device
        Device of the returned tensors (also forwarded to the per-set helper).

    Returns
    -------
    idx_nn, w_idx, w_w : torch.Tensor
        For a 1-D input, the arrays of the single per-set call are converted as is
        (no batch axis). For a 2-D input, the per-row results are stacked along a new
        leading axis of size B. ``idx_nn`` and ``w_idx`` are int64; ``w_w`` uses
        ``self.dtype``.
    """
    # A 1-D input is ONE set of ids: wrap it instead of iterating over its elements
    # (iterating would hand scalar ids to the per-set helper).
    if len(i_cell_ids.shape) < 2:
        id_sets, unbatched = [i_cell_ids], True
    else:
        id_sets, unbatched = i_cell_ids, False

    idx_nn, w_idx, w_w = [], [], []
    for cell_ids in id_sets:
        l_idx_nn, l_w_idx, l_w_w = self.make_idx_weights_from_one_cell_ids(
            cell_ids,
            polar=polar,
            gamma=gamma,
            device=device,
            allow_extrapolation=allow_extrapolation)
        idx_nn.append(l_idx_nn)
        w_idx.append(l_w_idx)
        w_w.append(l_w_w)

    def _to_tensor(parts, tensor_dtype):
        arr = np.asarray(parts[0]) if unbatched else np.stack(parts, 0)
        # torch.as_tensor converts directly to the target dtype. torch.Tensor(arr)
        # would first go through float32, which corrupts integer indices >= 2**24
        # and the low bits of float64 weights.
        return torch.as_tensor(arr, dtype=tensor_dtype, device=device)

    return (_to_tensor(idx_nn, torch.long),
            _to_tensor(w_idx, torch.long),
            _to_tensor(w_w, self.dtype))
|
|
390
|
+
|
|
391
|
+
def make_idx_weights_from_one_cell_ids(self,
                                       cell_ids,
                                       polar=False,
                                       gamma=1.0,
                                       device='cuda',
                                       allow_extrapolation=True):
    """
    Neighbour indices and bilinear kernel weights for a single set of cell ids.

    Parameters
    ----------
    cell_ids : ndarray, shape (Npix,)
        HEALPix pixel ids at ``self.nside`` (ordering given by ``self.nest``).
    polar : bool
        If True, the local (x, y) offsets are remixed with the longitude of each
        centre pixel, with a sign flip on the y terms between hemispheres.
    gamma : float
        Extra scale applied (with ``self.nside``) to the offsets before the
        bilinear interpolation.
    device : str
        Not used by this method; kept for signature compatibility.
    allow_extrapolation : bool
        Forwarded to ``bilinear_weights_NxN``.

    Returns
    -------
    idx_nn :
        Result of ``knn_healpix_ckdtree``: positions (into ``cell_ids``) of the
        KERNELSZ*KERNELSZ neighbours of each pixel.
    w_idx, w_w :
        Result of ``bilinear_weights_NxN``.
    """
    idx_nn = self.knn_healpix_ckdtree(cell_ids,
                                      self.KERNELSZ * self.KERNELSZ,
                                      self.nside,
                                      nest=self.nest,
                                      )

    mat_pt = self.rotation_matrices_from_healpix(self.nside, cell_ids, nest=self.nest)

    # Angles of every neighbour of every pixel (2-D: one row per pixel).
    t, p = hp.pix2ang(self.nside, cell_ids[idx_nn], nest=self.nest)

    vec_orig = hp.ang2vec(t, p)

    # Neighbour directions expressed in the frame given by each pixel's rotation
    # matrix; components 0 and 1 are used below as the local x / y offsets.
    vec_rot = np.einsum('mki,ijk->kmj', vec_orig, mat_pt)

    del vec_orig
    del mat_pt

    if polar:
        # Use ONE angle per centre pixel (column 0, the same convention __init__
        # uses with t[:, 0]). Taking [:, None] of the full 2-D angle arrays gives
        # (Npix, 1, P) factors that broadcast against the (Npix, P) offsets into
        # an (Npix, Npix, P) array.
        # NOTE(review): assumes knn_healpix_ckdtree lists the pixel itself first.
        t_c = t[:, 0][:, None]
        p_c = p[:, 0][:, None]
        rotate = np.where(t_c < np.pi / 2, 1.0, -1.0)  # +1 north, -1 south
        xx = np.cos(p_c) * vec_rot[:, :, 0] - rotate * np.sin(p_c) * vec_rot[:, :, 1]
        yy = -np.sin(p_c) * vec_rot[:, :, 0] - rotate * np.cos(p_c) * vec_rot[:, :, 1]
    else:
        xx = vec_rot[:, :, 0]
        yy = vec_rot[:, :, 1]

    del vec_rot
    del t
    del p

    w_idx, w_w = self.bilinear_weights_NxN(xx * self.nside * gamma,
                                           yy * self.nside * gamma,
                                           allow_extrapolation=allow_extrapolation)

    del xx
    del yy

    return idx_nn, w_idx, w_w
|
|
436
|
+
|
|
437
|
+
def make_idx_weights(self, polar=False, gamma=1.0, device='cuda', allow_extrapolation=True, return_index=False):
    """
    Compute and cache the neighbour tables for ``self.cell_ids``.

    Sets ``self.idx_nn`` and ``self.w_idx`` (int64 tensors) and ``self.w_w``
    (tensor of ``self.dtype``), all on ``device``, from
    ``make_idx_weights_from_one_cell_ids``. Returns None.

    Parameters
    ----------
    polar, gamma, allow_extrapolation :
        Forwarded to ``make_idx_weights_from_one_cell_ids``.
    device : str or torch.device
        Device of the cached tensors.
    return_index : bool
        Accepted but currently unused.
    """
    idx_nn, w_idx, w_w = self.make_idx_weights_from_one_cell_ids(self.cell_ids,
                                                                 polar=polar,
                                                                 gamma=gamma,
                                                                 device=device,
                                                                 allow_extrapolation=allow_extrapolation)

    # Ensure types/devices. torch.as_tensor converts straight to the target dtype;
    # torch.Tensor(...) would round through float32 first, corrupting indices
    # >= 2**24 and the low bits of float64 weights.
    self.idx_nn = torch.as_tensor(idx_nn, dtype=torch.long, device=device)
    self.w_idx = torch.as_tensor(w_idx, dtype=torch.long, device=device)
    self.w_w = torch.as_tensor(w_w, dtype=self.dtype, device=device)
|
|
312
449
|
|
|
313
450
|
def _grid_index(self, xi, yi):
|
|
314
451
|
"""
|
|
@@ -411,7 +548,184 @@ class HOrientedConvol:
|
|
|
411
548
|
|
|
412
549
|
return idx, w
|
|
413
550
|
|
|
414
|
-
def Convol_torch(self, im, ww):
|
|
551
|
+
def Convol_torch(self, im, ww, cell_ids=None, nside=None):
    """
    Batched KERNELSZxKERNELSZ neighbourhood aggregation in pure PyTorch
    (generalization of the 3x3 case).

    Parameters
    ----------
    im : Tensor or array-like, shape (B, C_i, Npix)
        Input maps; non-tensors are converted with ``self.device`` / ``self.dtype``.
    ww : Tensor or array-like, one of
        (C_i, C_o, M) | (C_i, C_o, M, S) | (B, C_i, C_o, M) | (B, C_i, C_o, M, S)
    cell_ids : ndarray or Tensor, optional
        - None: use the cached ``self.idx_nn`` / ``self.w_idx`` / ``self.w_w``
          (computed from ``self.cell_ids`` on first use). Both unbatched tables
          (as stored by ``make_idx_weights``) and batched ones are accepted.
        - (Npix,): recompute once, shared by the whole batch.
        - (B, Npix): recompute per sample.
    nside : int, optional
        Kept for API compatibility only. It is not forwarded: the neighbourhood
        geometry is computed with ``self.nside`` by
        ``make_idx_weights_from_cell_ids``.

    Returns
    -------
    out : Tensor, shape (B, C_o, Npix)

    Raises
    ------
    AssertionError
        On inconsistent ``im`` / ``ww`` / ``cell_ids`` shapes.
    ValueError
        On unsupported index-table or weight ranks.
    """
    # ---- Basic checks / casting ----
    if not isinstance(im, torch.Tensor):
        im = torch.as_tensor(im, device=self.device, dtype=self.dtype)
    if not isinstance(ww, torch.Tensor):
        ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)

    assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"
    B, C_i, Npix = im.shape
    device = im.device
    dtype = im.dtype

    def _batched(idx_nn, w_idx, w_w):
        # Normalise the tables to idx (B, Npix, P), w_idx / w_w (B, Npix, S, P).
        # Accepted inputs: unbatched idx (Npix, P) with w_idx (Npix, P) or
        # (Npix, S, P); or batched idx (b, Npix, P) with w_idx (b, Npix, P) or
        # (b, Npix, S, P), where b is 1 or B. The rank of idx_nn disambiguates a
        # 3-D w_idx.
        idx_nn = torch.as_tensor(idx_nn, device=device, dtype=torch.long)
        w_idx = torch.as_tensor(w_idx, device=device, dtype=torch.long)
        w_w = torch.as_tensor(w_w, device=device, dtype=dtype)
        if idx_nn.ndim == 2:
            idx_nn, w_idx, w_w = idx_nn[None], w_idx[None], w_w[None]
        if idx_nn.ndim != 3 or idx_nn.shape[1] != Npix:
            raise ValueError(
                f"`idx_nn` must be (Npix,P) or (B,Npix,P) with Npix={Npix}, got {tuple(idx_nn.shape)}")
        if w_idx.ndim == 3:
            w_idx, w_w = w_idx[:, :, None, :], w_w[:, :, None, :]
        elif w_idx.ndim != 4:
            raise ValueError(
                f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; "
                f"expected (Npix,P), (Npix,S,P), (B,Npix,P) or (B,Npix,S,P)")
        if idx_nn.shape[0] not in (1, B):
            raise ValueError(f"index tables have batch size {idx_nn.shape[0]}, expected 1 or B={B}")
        return (idx_nn.expand(B, -1, -1),
                w_idx.expand(B, -1, -1, -1),
                w_w.expand(B, -1, -1, -1))

    # ---- Neighbour indices and interpolation weights ----
    if cell_ids is not None:
        if isinstance(cell_ids, torch.Tensor):
            cid = cell_ids.detach().to("cpu").numpy()
        else:
            cid = np.asarray(cell_ids)

        if cid.ndim == 1:
            assert cid.shape[0] == Npix, f"cell_ids must be (Npix,) with Npix={Npix}, got {cid.shape}"
            # One shared set of ids: compute it as a batch of one and let
            # _batched broadcast it over B.
            cid = cid[None, :]
        elif cid.ndim == 2:
            assert cid.shape[0] == B and cid.shape[1] == Npix, \
                f"cell_ids must be (B,Npix) with B={B},Npix={Npix}, got {cid.shape}"
        else:
            raise ValueError(f"Unsupported cell_ids shape {cid.shape}")

        # Options are passed by keyword: forwarding `nside` positionally would
        # bind it to `polar`.
        idx_nn, w_idx, w_w = self.make_idx_weights_from_cell_ids(
            cid,
            polar=self.polar,
            gamma=self.gamma,
            device=device,
            allow_extrapolation=self.allow_extrapolation)
    else:
        # Use precomputed tables (shared for batch), computing them on first use.
        if self.w_idx is None:
            l_cell = self.cell_ids[None, :] if self.cell_ids.ndim == 1 else self.cell_ids
            self.idx_nn, self.w_idx, self.w_w = self.make_idx_weights_from_cell_ids(
                l_cell,
                polar=self.polar,
                gamma=self.gamma,
                device=self.device,
                allow_extrapolation=self.allow_extrapolation)
        idx_nn, w_idx, w_w = self.idx_nn, self.w_idx, self.w_w

    idx_nn_eff, w_idx_eff, w_w_eff = _batched(idx_nn, w_idx, w_w)
    S = w_idx_eff.shape[2]

    # ---- 1) Gather neighbour values from im along Npix -> (B, C_i, Npix, P)
    rim = torch.take_along_dim(
        im.unsqueeze(-1),              # (B, C_i, Npix, 1)
        idx_nn_eff[:, None, :, :],     # (B, 1, Npix, P)
        dim=2,
    )

    # ---- 2) Normalize ww to (B, C_i, C_o, M, S)
    # NOTE(review): a 4-D ww is classified by its first two sizes, so a
    # (C_i, C_o, M, S) kernel with C_o == C_i is read as (B, C_i, C_o, M) when
    # B == C_i and rejected otherwise; pass a 5-D ww in that case.
    if ww.ndim == 3:
        C_i_w, C_o, M = ww.shape
        assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
        ww_eff = ww.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, -1, S)
    elif ww.ndim == 4:
        if ww.shape[0] == C_i and ww.shape[1] != C_i:
            # (C_i, C_o, M, S)
            C_i_w, C_o, M, S_w = ww.shape
            assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
            assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
            ww_eff = ww.unsqueeze(0).expand(B, -1, -1, -1, -1)
        elif ww.shape[0] == B:
            # (B, C_i, C_o, M)
            _, C_i_w, C_o, M = ww.shape
            assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
            ww_eff = ww.unsqueeze(-1).expand(-1, -1, -1, -1, S)
        else:
            raise ValueError(f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)")
    elif ww.ndim == 5:
        # (B, C_i, C_o, M, S)
        assert ww.shape[0] == B and ww.shape[1] == C_i, "ww batch/C_i mismatch"
        S_w = ww.shape[4]
        assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
        ww_eff = ww
    else:
        raise ValueError(f"Unsupported ww shape {tuple(ww.shape)}")

    # ---- 3) Gather along M using w_idx_eff -> (B, C_i, C_o, Npix, S, P)
    rw = torch.take_along_dim(
        ww_eff.unsqueeze(-1),                  # (B, C_i, C_o, M, S, 1)
        w_idx_eff[:, None, None, :, :, :],     # (B, 1, 1, Npix, S, P)
        dim=3,                                 # gather along M
    )
    # ---- 4) Apply the interpolation weights
    rw = rw * w_w_eff[:, None, None, :, :, :]

    # ---- 5) Combine neighbour values and weights
    rim_exp = rim[:, :, None, :, None, :]      # (B, C_i, 1, Npix, 1, P)
    out_ci = (rim_exp * rw).sum(dim=-1)        # sum over P -> (B, C_i, C_o, Npix, S)
    out_ci = out_ci.sum(dim=-1)                # sum over S -> (B, C_i, C_o, Npix)
    return out_ci.sum(dim=1)                   # sum over C_i -> (B, C_o, Npix)
|
|
727
|
+
|
|
728
|
+
def Convol_torch_old(self, im, ww,cell_ids=None,nside=None):
|
|
415
729
|
"""
|
|
416
730
|
Batched KERNELSZxKERNELSZ neighborhood aggregation in pure PyTorch (generalization of the 3x3 case).
|
|
417
731
|
|
|
@@ -426,7 +740,11 @@ class HOrientedConvol:
|
|
|
426
740
|
(C_i, C_o, M, S)
|
|
427
741
|
(B, C_i, C_o, M)
|
|
428
742
|
(B, C_i, C_o, M, S)
|
|
429
|
-
|
|
743
|
+
|
|
744
|
+
cell_ids : ndarray
|
|
745
|
+
If cell_ids is not None recompute the index and do not use the precomputed ones.
|
|
746
|
+
Note : The computation is then much longer.
|
|
747
|
+
|
|
430
748
|
Class members (already tensors; will be aligned to im.device/dtype):
|
|
431
749
|
-------------------------------------------------------------------
|
|
432
750
|
self.idx_nn : LongTensor, shape (Npix, P)
|
|
@@ -443,18 +761,27 @@ class HOrientedConvol:
|
|
|
443
761
|
Aggregated output per center pixel for each batch sample.
|
|
444
762
|
"""
|
|
445
763
|
# ---- Basic checks ----
|
|
764
|
+
if not isinstance(im,torch.Tensor):
|
|
765
|
+
im=torch.Tensor(im).to(device=self.device, dtype=self.dtype)
|
|
766
|
+
if not isinstance(ww,torch.Tensor):
|
|
767
|
+
ww=torch.Tensor(ww).to(device=self.device, dtype=self.dtype)
|
|
768
|
+
|
|
446
769
|
assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"
|
|
770
|
+
|
|
447
771
|
assert ww.shape[2]==self.KERNELSZ*self.KERNELSZ, f"`ww` must be (C_i, C_o, KERNELSZ*KERNELSZ), got {tuple(ww.shape)}"
|
|
448
772
|
|
|
449
773
|
B, C_i, Npix = im.shape
|
|
450
774
|
device = im.device
|
|
451
775
|
dtype = im.dtype
|
|
776
|
+
|
|
777
|
+
if cell_ids is not None:
|
|
452
778
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
779
|
+
idx_nn,w_idx,w_w = self.make_idx_weights_from_cell_ids(cell_ids,nside,device=device)
|
|
780
|
+
else:
|
|
781
|
+
idx_nn = self.idx_nn # (Npix, P)
|
|
782
|
+
w_idx = self.w_idx # (Npix, P) or (Npix, S, P)
|
|
783
|
+
w_w = self.w_w # (Npix, P) or (Npix, S, P)
|
|
784
|
+
|
|
458
785
|
# Neighbor count P inferred from idx_nn
|
|
459
786
|
assert idx_nn.ndim == 2 and idx_nn.size(0) == Npix, \
|
|
460
787
|
f"`idx_nn` must be (Npix, P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
|
|
@@ -542,5 +869,82 @@ class HOrientedConvol:
|
|
|
542
869
|
|
|
543
870
|
return out
|
|
544
871
|
|
|
872
|
+
def Down(self, im, cell_ids=None, nside=None):
    """
    Downgrade ``im`` one level with the foscat ``ud_grade_2`` operator.

    The foscat helper ``self.f`` is created on first use with the precision of
    ``self.dtype``.

    Parameters
    ----------
    im : tensor-like
        Input data; for batched ``cell_ids`` its first axis is the batch.
    cell_ids : array-like, optional
        None to use ``self.cell_ids``; a 1-D array for one shared set of ids; or a
        2-D (B, Npix) array with one set per sample.
    nside : int, optional
        Resolution of ``cell_ids``; defaults to ``self.nside``.

    Returns
    -------
    With ``cell_ids`` None: only the downgraded data.
    With 1-D ``cell_ids``: whatever ``self.f.ud_grade_2`` returns.
    With 2-D ``cell_ids``: (data, cell_ids), each stacked over the batch axis.
    """
    if self.f is None:
        precision = 'float64' if self.dtype == torch.float64 else 'float32'
        self.f = sc.funct(KERNELSZ=self.KERNELSZ, all_type=precision)

    if cell_ids is None:
        degraded, _ignored_ids = self.f.ud_grade_2(im, cell_ids=self.cell_ids, nside=self.nside)
        return degraded

    level = self.nside if nside is None else nside

    if len(cell_ids.shape) == 1:
        return self.f.ud_grade_2(im, cell_ids=cell_ids, nside=level)

    assert im.shape[0] == cell_ids.shape[0], \
        f"cell_ids and data should have the same batch size (first column), got data={im.shape},cell_ids={cell_ids.shape}"

    # One ud_grade_2 call per sample, then restack data and ids over the batch.
    pairs = [self.f.ud_grade_2(im[b], cell_ids=cell_ids[b], nside=level)
             for b in range(im.shape[0])]
    maps = torch.stack([m for m, _ in pairs], dim=0)      # (B, ..., Npix)
    new_ids = torch.stack([c for _, c in pairs], dim=0)   # (B, Npix)
    return maps, new_ids
|
|
901
|
+
|
|
902
|
+
def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
    """
    Upgrade ``im`` from ``nside`` to ``2*nside`` with the foscat ``up_grade`` operator.

    The foscat helper ``self.f`` is created on first use with the precision of
    ``self.dtype``.

    Parameters
    ----------
    im : tensor-like
        Input data; for batched ``cell_ids`` its first axis is the batch.
    cell_ids : array-like, optional
        None to use ``self.cell_ids`` at ``self.nside``; a 1-D array for one shared
        set of ids; or a 2-D (B, Npix) array with one set per sample.
    nside : int, optional
        Resolution of ``cell_ids``; defaults to ``self.nside``.
    o_cell_ids : array-like, optional
        Output cell ids forwarded to ``up_grade`` (per sample when ``cell_ids`` is
        2-D). May be None in every mode; it is not used when ``cell_ids`` is None.

    Returns
    -------
    The upgraded data; for 2-D ``cell_ids`` the per-sample results are stacked
    along a new leading axis.
    """
    if self.f is None:
        if self.dtype == torch.float64:
            self.f = sc.funct(KERNELSZ=self.KERNELSZ, all_type='float64')
        else:
            self.f = sc.funct(KERNELSZ=self.KERNELSZ, all_type='float32')

    if cell_ids is None:
        return self.f.up_grade(im, self.nside * 2, cell_ids=self.cell_ids, nside=self.nside)

    if nside is None:
        nside = self.nside

    if len(cell_ids.shape) == 1:
        return self.f.up_grade(im, nside * 2, cell_ids=cell_ids, nside=nside, o_cell_ids=o_cell_ids)

    assert im.shape[0] == cell_ids.shape[0], \
        f"cell_ids and data should have the same batch size (first column), got data={im.shape},cell_ids={cell_ids.shape}"
    # o_cell_ids is optional here too, as in the 1-D branch: only check and
    # index it when it is given.
    if o_cell_ids is not None:
        assert im.shape[0] == o_cell_ids.shape[0], \
            f"cell_ids and data should have the same batch size (first column), got data={im.shape},o_cell_ids={o_cell_ids.shape}"

    result = []
    for k in range(im.shape[0]):
        out_ids = None if o_cell_ids is None else o_cell_ids[k]
        result.append(self.f.up_grade(im[k], nside * 2, cell_ids=cell_ids[k], nside=nside, o_cell_ids=out_ids))

    return torch.stack(result, dim=0)  # (B,...,Npix)
|
|
934
|
+
|
|
935
|
+
def to_tensor(self, x):
    """
    Cast ``x`` with the backend of the foscat helper ``self.f``.

    ``self.f`` is created on first use with the precision of ``self.dtype``.
    """
    if self.f is None:
        self.f = sc.funct(KERNELSZ=self.KERNELSZ,
                          all_type='float64' if self.dtype == torch.float64 else 'float32')
    backend = self.f.backend
    return backend.bk_cast(x)
|
|
942
|
+
|
|
943
|
+
def to_numpy(self, x):
    """
    Return ``x`` as a NumPy array.

    NumPy arrays are returned unchanged (same object). Torch tensors are detached
    and copied to host memory first, because ``Tensor.numpy()`` refuses tensors
    that require grad. Any other object is converted via ``x.cpu().numpy()``.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return x.cpu().numpy()
|
|
947
|
+
|
|
948
|
+
|
|
545
949
|
|
|
546
950
|
|