foscat 2025.9.1__py3-none-any.whl → 2025.9.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +160 -93
- foscat/FoCUS.py +80 -267
- foscat/HOrientedConvol.py +233 -250
- foscat/HealBili.py +12 -8
- foscat/Plot.py +1112 -142
- foscat/SphericalStencil.py +1346 -0
- foscat/UNET.py +21 -7
- foscat/healpix_unet_torch.py +656 -171
- foscat/scat_cov.py +2 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/METADATA +1 -1
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/RECORD +14 -13
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/WHEEL +0 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/top_level.txt +0 -0
foscat/HOrientedConvol.py
CHANGED
|
@@ -20,6 +20,7 @@ class HOrientedConvol:
|
|
|
20
20
|
no_cell_ids=False,
|
|
21
21
|
):
|
|
22
22
|
|
|
23
|
+
|
|
23
24
|
if dtype=='float64':
|
|
24
25
|
self.dtype=torch.float64
|
|
25
26
|
else:
|
|
@@ -72,7 +73,7 @@ class HOrientedConvol:
|
|
|
72
73
|
t,p = hp.pix2ang(nside,self.cell_ids[idx_nn],nest=True)
|
|
73
74
|
else:
|
|
74
75
|
t,p = hp.pix2ang(nside,idx_nn,nest=True)
|
|
75
|
-
|
|
76
|
+
|
|
76
77
|
self.t=t[:,0]
|
|
77
78
|
self.p=p[:,0]
|
|
78
79
|
vec_orig=hp.ang2vec(t,p)
|
|
@@ -357,6 +358,47 @@ class HOrientedConvol:
|
|
|
357
358
|
|
|
358
359
|
return csr_array((w, (indice_1_0, indice_1_1)), shape=(12*self.nside**2, 12*self.nside**2*NORIENT))
|
|
359
360
|
|
|
361
|
+
def make_idx_weights_from_cell_ids(self,
|
|
362
|
+
i_cell_ids,
|
|
363
|
+
polar=False,
|
|
364
|
+
gamma=1.0,
|
|
365
|
+
device='cuda',
|
|
366
|
+
allow_extrapolation=True):
|
|
367
|
+
"""
|
|
368
|
+
Accept 1D (Npix,) or 2D (B, Npix) cell_ids and return
|
|
369
|
+
tensors batched sur la 1ère dim (B, ...).
|
|
370
|
+
"""
|
|
371
|
+
# → cast numpy
|
|
372
|
+
if torch.is_tensor(i_cell_ids):
|
|
373
|
+
cid = i_cell_ids.detach().cpu().numpy()
|
|
374
|
+
else:
|
|
375
|
+
cid = np.asarray(i_cell_ids)
|
|
376
|
+
|
|
377
|
+
# --- 1D: pas de boucle, on calcule une fois, puis on ajoute l'axe batch
|
|
378
|
+
if cid.ndim == 1:
|
|
379
|
+
l_idx_nn, l_w_idx, l_w_w = self.make_idx_weights_from_one_cell_ids(
|
|
380
|
+
cid, polar=polar, gamma=gamma, device=device,
|
|
381
|
+
allow_extrapolation=allow_extrapolation
|
|
382
|
+
)
|
|
383
|
+
idx_nn = torch.as_tensor(l_idx_nn, device=device, dtype=torch.long)[None, ...] # (1, Npix, P)
|
|
384
|
+
w_idx = torch.as_tensor(l_w_idx, device=device, dtype=torch.long)[None, ...] # (1, Npix, S, P) ou (1, Npix, P)
|
|
385
|
+
w_w = torch.as_tensor(l_w_w, device=device, dtype=self.dtype)[None, ...] # (1, Npix, S, P) ou (1, Npix, P)
|
|
386
|
+
return idx_nn, w_idx, w_w
|
|
387
|
+
|
|
388
|
+
# --- 2D: boucle sur b, empilement en (B, ...)
|
|
389
|
+
elif cid.ndim == 2:
|
|
390
|
+
outs = [ self.make_idx_weights_from_one_cell_ids(
|
|
391
|
+
cid[k], polar=polar, gamma=gamma, device=device,
|
|
392
|
+
allow_extrapolation=allow_extrapolation)
|
|
393
|
+
for k in range(cid.shape[0]) ]
|
|
394
|
+
idx_nn = torch.as_tensor(np.stack([o[0] for o in outs], axis=0), device=device, dtype=torch.long)
|
|
395
|
+
w_idx = torch.as_tensor(np.stack([o[1] for o in outs], axis=0), device=device, dtype=torch.long)
|
|
396
|
+
w_w = torch.as_tensor(np.stack([o[2] for o in outs], axis=0), device=device, dtype=self.dtype)
|
|
397
|
+
return idx_nn, w_idx, w_w
|
|
398
|
+
|
|
399
|
+
else:
|
|
400
|
+
raise ValueError(f"Unsupported cell_ids ndim={cid.ndim}; expected 1 or 2.")
|
|
401
|
+
'''
|
|
360
402
|
def make_idx_weights_from_cell_ids(self,i_cell_ids,
|
|
361
403
|
polar=False,
|
|
362
404
|
gamma=1.0,
|
|
@@ -387,7 +429,8 @@ class HOrientedConvol:
|
|
|
387
429
|
w_w = torch.Tensor(np.stack(w_w,0)).to(device=device, dtype=self.dtype)
|
|
388
430
|
|
|
389
431
|
return idx_nn,w_idx,w_w
|
|
390
|
-
|
|
432
|
+
'''
|
|
433
|
+
|
|
391
434
|
def make_idx_weights_from_one_cell_ids(self,
|
|
392
435
|
cell_ids,
|
|
393
436
|
polar=False,
|
|
@@ -428,11 +471,34 @@ class HOrientedConvol:
|
|
|
428
471
|
w_idx,w_w = self.bilinear_weights_NxN(xx*self.nside*gamma,
|
|
429
472
|
yy*self.nside*gamma,
|
|
430
473
|
allow_extrapolation=allow_extrapolation)
|
|
474
|
+
'''
|
|
475
|
+
# calib : [Npix, K]
|
|
476
|
+
calib = np.zeros((w_idx.shape[0], w_idx.shape[2]))
|
|
477
|
+
# Hypothèses :
|
|
478
|
+
# w_idx.shape == (Npix, M, K) et w_w.shape == (Npix, M, K)
|
|
479
|
+
Npix, M, K = w_idx.shape
|
|
480
|
+
nb_cols = K
|
|
481
|
+
|
|
482
|
+
# 1) Accumulation par "bincount" avec décalage de ligne
|
|
483
|
+
row_ids = np.arange(Npix, dtype=np.int64)[:, None, None] * nb_cols
|
|
484
|
+
flat_idx = (row_ids + w_idx).ravel() # indices dans [0, Npix*9)
|
|
485
|
+
weights = w_w.ravel().astype(np.float64) # ou dtype de ton choix
|
|
486
|
+
|
|
487
|
+
calib = np.bincount(flat_idx, weights, minlength=Npix*nb_cols)\
|
|
488
|
+
.reshape(Npix, nb_cols)
|
|
489
|
+
|
|
490
|
+
# 2) Réinjection dans norm_a selon w_idx
|
|
491
|
+
norm_a = calib[np.arange(Npix)[:, None, None], w_idx]
|
|
492
|
+
|
|
493
|
+
w_w /= norm_a
|
|
494
|
+
w_w = np.clip(w_w,0.0,1.0)
|
|
431
495
|
|
|
432
|
-
|
|
433
|
-
|
|
496
|
+
w_w[np.isnan(w_w)]=0.0
|
|
497
|
+
'''
|
|
498
|
+
#del xx
|
|
499
|
+
#del yy
|
|
434
500
|
|
|
435
|
-
return idx_nn,w_idx,w_w
|
|
501
|
+
return idx_nn,w_idx,w_w,xx,yy
|
|
436
502
|
|
|
437
503
|
def make_idx_weights(self,polar=False,gamma=1.0,device='cuda',allow_extrapolation=True,return_index=False):
|
|
438
504
|
|
|
@@ -547,27 +613,58 @@ class HOrientedConvol:
|
|
|
547
613
|
idx = np.stack([i00, i10, i01, i11], axis=1).astype(np.int64)
|
|
548
614
|
|
|
549
615
|
return idx, w
|
|
616
|
+
|
|
617
|
+
# --- Add inside class HOrientedConvol, just above Convol_torch ---
|
|
618
|
+
def _convol_single(self, im1: torch.Tensor, ww: torch.Tensor, cell_ids=None, nside=None):
|
|
619
|
+
"""
|
|
620
|
+
Single-sample path. im1: (1, C_i, Npix_1). Returns (1, C_o, Npix_1).
|
|
621
|
+
"""
|
|
622
|
+
if not isinstance(im1, torch.Tensor):
|
|
623
|
+
im1 = torch.as_tensor(im1, device=self.device, dtype=self.dtype)
|
|
624
|
+
if not isinstance(ww, torch.Tensor):
|
|
625
|
+
ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)
|
|
626
|
+
assert im1.ndim == 3 and im1.shape[0] == 1, f"expected (1, C_i, Npix), got {tuple(im1.shape)}"
|
|
627
|
+
|
|
628
|
+
# Reuse the existing Convol_torch core by faking B=1 shapes.
|
|
629
|
+
# We call the existing (batched) implementation with B=1.
|
|
630
|
+
return self.Convol_torch(im1, ww, cell_ids=cell_ids, nside=nside) # returns (1, C_o, Npix_1)
|
|
550
631
|
|
|
632
|
+
# --- Replace the first lines of Convol_torch with a dispatcher ---
|
|
551
633
|
def Convol_torch(self, im, ww, cell_ids=None, nside=None):
|
|
552
634
|
"""
|
|
553
|
-
Batched KERNELSZxKERNELSZ
|
|
635
|
+
Batched KERNELSZxKERNELSZ aggregation.
|
|
554
636
|
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
ww : Tensor, shapes supported:
|
|
559
|
-
(C_i, C_o, M) | (C_i, C_o, M, S) | (B, C_i, C_o, M) | (B, C_i, C_o, M, S)
|
|
560
|
-
cell_ids : ndarray or Tensor
|
|
561
|
-
- None: use precomputed self.idx_nn / self.w_idx / self.w_w (shared for batch).
|
|
562
|
-
- (Npix,): recompute once (shared for batch).
|
|
563
|
-
- (B, Npix): recompute per-sample (different for each b).
|
|
564
|
-
|
|
565
|
-
Returns
|
|
566
|
-
-------
|
|
567
|
-
out : Tensor, shape (B, C_o, Npix)
|
|
637
|
+
Accepts either:
|
|
638
|
+
- im: Tensor (B, C_i, Npix) with one shared or per-batch (B,Npix) cell_ids
|
|
639
|
+
- im: list/tuple of Tensors, each (C_i, Npix_b), with cell_ids a list of arrays
|
|
568
640
|
"""
|
|
569
641
|
import torch
|
|
570
642
|
|
|
643
|
+
# (A) Variable-length per-sample path: im is a list/tuple OR cell_ids is a list/tuple
|
|
644
|
+
if isinstance(im, (list, tuple)) or isinstance(cell_ids, (list, tuple)):
|
|
645
|
+
# Normalize to lists
|
|
646
|
+
im_list = im if isinstance(im, (list, tuple)) else [im]
|
|
647
|
+
cid_list = cell_ids if isinstance(cell_ids, (list, tuple)) else [cell_ids] * len(im_list)
|
|
648
|
+
assert len(im_list) == len(cid_list), "im list and cell_ids list must have same length"
|
|
649
|
+
|
|
650
|
+
outs = []
|
|
651
|
+
for xb, cb in zip(im_list, cid_list):
|
|
652
|
+
# xb: (C_i, Npix_b) -> (1, C_i, Npix_b)
|
|
653
|
+
if not torch.is_tensor(xb):
|
|
654
|
+
xb = torch.as_tensor(xb, device=self.device, dtype=self.dtype)
|
|
655
|
+
if xb.dim() == 2:
|
|
656
|
+
xb = xb.unsqueeze(0)
|
|
657
|
+
elif xb.dim() != 3 or xb.shape[0] != 1:
|
|
658
|
+
raise ValueError(f"Each sample must be (C,N) or (1,C,N); got {tuple(xb.shape)}")
|
|
659
|
+
|
|
660
|
+
yb = self._convol_single(xb, ww, cell_ids=cb, nside=nside) # (1, C_o, Npix_b)
|
|
661
|
+
outs.append(yb.squeeze(0)) # -> (C_o, Npix_b)
|
|
662
|
+
return outs # List[Tensor], each (C_o, Npix_b)
|
|
663
|
+
|
|
664
|
+
# (B) Standard fixed-length batched path (your current implementation)
|
|
665
|
+
# ... keep your existing Convol_torch body from here unchanged ...
|
|
666
|
+
# (paste your current function body starting from the type casting and assertions)
|
|
667
|
+
|
|
571
668
|
# ---- Basic checks / casting ----
|
|
572
669
|
if not isinstance(im, torch.Tensor):
|
|
573
670
|
im = torch.as_tensor(im, device=self.device, dtype=self.dtype)
|
|
@@ -585,55 +682,37 @@ class HOrientedConvol:
|
|
|
585
682
|
# w_idx_eff : (B, Npix, S, P)
|
|
586
683
|
# w_w_eff : (B, Npix, S, P)
|
|
587
684
|
if cell_ids is not None:
|
|
588
|
-
#
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
if cid.ndim == 1:
|
|
595
|
-
# single set of ids for the whole batch
|
|
596
|
-
idx_nn, w_idx, w_w = self.make_idx_weights_from_cell_ids(cid, nside, device=device)
|
|
597
|
-
assert idx_nn.ndim == 2, "idx_nn expected (Npix,P)"
|
|
598
|
-
P = idx_nn.shape[1]
|
|
599
|
-
if w_idx.ndim == 2:
|
|
600
|
-
# (Npix,P) -> (B,Npix,1,P)
|
|
601
|
-
S = 1
|
|
602
|
-
w_idx_eff = w_idx[None, :, None, :].expand(B, -1, -1, -1)
|
|
603
|
-
w_w_eff = w_w[None, :, None, :].expand(B, -1, -1, -1)
|
|
604
|
-
elif w_idx.ndim == 3:
|
|
605
|
-
# (Npix,S,P) -> (B,Npix,S,P)
|
|
606
|
-
S = w_idx.shape[1]
|
|
607
|
-
w_idx_eff = w_idx[None, ...].expand(B, -1, -1, -1)
|
|
608
|
-
w_w_eff = w_w[None, ...].expand(B, -1, -1, -1)
|
|
685
|
+
# ---- Recompute (idx_nn, w_idx, w_w) depending on cell_ids shape ----
|
|
686
|
+
# Normaliser: accepter Tensor, ndarray ou list/tuple de 1 élément (cas var-length, B=1)
|
|
687
|
+
if isinstance(cell_ids, (list, tuple)):
|
|
688
|
+
# liste d'ids (souvent longueur 1 en var-length)
|
|
689
|
+
if len(cell_ids) == 1:
|
|
690
|
+
cid = np.asarray(cell_ids[0])[None, :] # -> (1, Npix)
|
|
609
691
|
else:
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
#
|
|
615
|
-
assert cid.shape[0] == B and cid.shape[1] == Npix, \
|
|
616
|
-
f"cell_ids must be (B,Npix) with B={B},Npix={Npix}, got {cid.shape}"
|
|
617
|
-
S_ref = None
|
|
618
|
-
|
|
619
|
-
idx_nn_eff, w_idx_eff, w_w_eff = self.make_idx_weights_from_cell_ids(cid,
|
|
620
|
-
nside,
|
|
621
|
-
device=device)
|
|
622
|
-
P = idx_nn_eff.shape[-1]
|
|
623
|
-
S = w_idx_eff.shape[-2]
|
|
624
|
-
|
|
692
|
+
# si jamais >1, on essaie d'empiler (doit avoir même Npix par élément)
|
|
693
|
+
cid = np.stack([np.asarray(c) for c in cell_ids], axis=0)
|
|
694
|
+
elif torch.is_tensor(cell_ids):
|
|
695
|
+
c = cell_ids.detach().cpu().numpy()
|
|
696
|
+
cid = c if c.ndim != 1 else c[None, :] # uniformiser en 2D quand B=1
|
|
625
697
|
else:
|
|
626
|
-
|
|
698
|
+
c = np.asarray(cell_ids)
|
|
699
|
+
cid = c if c.ndim != 1 else c[None, :]
|
|
627
700
|
|
|
628
|
-
#
|
|
701
|
+
# cid est maintenant (B, Npix)
|
|
702
|
+
idx_nn_eff, w_idx_eff, w_w_eff = self.make_idx_weights_from_cell_ids(
|
|
703
|
+
cid, nside, device=device
|
|
704
|
+
)
|
|
705
|
+
# shapes: (B, Npix, P), (B, Npix, S, P|P), (B, Npix, S, P|P)
|
|
706
|
+
P = idx_nn_eff.shape[-1]
|
|
707
|
+
S = w_idx_eff.shape[-2] if w_idx_eff.ndim == 4 else 1
|
|
708
|
+
|
|
709
|
+
# s’assurer des dtypes/devices
|
|
629
710
|
idx_nn_eff = torch.as_tensor(idx_nn_eff, device=device, dtype=torch.long)
|
|
630
711
|
w_idx_eff = torch.as_tensor(w_idx_eff, device=device, dtype=torch.long)
|
|
631
712
|
w_w_eff = torch.as_tensor(w_w_eff, device=device, dtype=dtype)
|
|
632
|
-
|
|
633
713
|
else:
|
|
634
714
|
# Use precomputed (shared for batch)
|
|
635
715
|
if self.w_idx is None:
|
|
636
|
-
|
|
637
716
|
if self.cell_ids.ndim==1:
|
|
638
717
|
l_cell=self.cell_ids[None,:]
|
|
639
718
|
else:
|
|
@@ -724,214 +803,118 @@ class HOrientedConvol:
|
|
|
724
803
|
out = out_ci.sum(dim=1) # sum over C_i -> (B, C_o, Npix)
|
|
725
804
|
|
|
726
805
|
return out
|
|
727
|
-
|
|
728
|
-
def Convol_torch_old(self, im, ww,cell_ids=None,nside=None):
|
|
729
|
-
"""
|
|
730
|
-
Batched KERNELSZxKERNELSZ neighborhood aggregation in pure PyTorch (generalization of the 3x3 case).
|
|
731
806
|
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
Aggregated output per center pixel for each batch sample.
|
|
807
|
+
def _to_numpy_1d(self, ids):
|
|
808
|
+
"""Return a 1D numpy array of int64 for a single set of cell ids."""
|
|
809
|
+
import numpy as np, torch
|
|
810
|
+
if isinstance(ids, np.ndarray):
|
|
811
|
+
return ids.reshape(-1).astype(np.int64, copy=False)
|
|
812
|
+
if torch.is_tensor(ids):
|
|
813
|
+
return ids.detach().cpu().to(torch.long).view(-1).numpy()
|
|
814
|
+
# python list/tuple of ints
|
|
815
|
+
return np.asarray(ids, dtype=np.int64).reshape(-1)
|
|
816
|
+
|
|
817
|
+
def _is_varlength_batch(self, ids):
|
|
818
|
+
"""
|
|
819
|
+
True if ids is a list/tuple of per-sample id arrays (var-length batch).
|
|
820
|
+
False if ids is a single array/tensor of ids (shared for whole batch).
|
|
821
|
+
"""
|
|
822
|
+
import numpy as np, torch
|
|
823
|
+
if isinstance(ids, (list, tuple)):
|
|
824
|
+
return True
|
|
825
|
+
if isinstance(ids, np.ndarray) and ids.ndim == 2:
|
|
826
|
+
# This would be a dense (B, Npix) matrix -> NOT var-length list
|
|
827
|
+
return False
|
|
828
|
+
if torch.is_tensor(ids) and ids.dim() == 2:
|
|
829
|
+
return False
|
|
830
|
+
return False
|
|
831
|
+
|
|
832
|
+
def Down(self, im, cell_ids=None, nside=None,max_poll=False):
|
|
833
|
+
"""
|
|
834
|
+
If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
|
|
835
|
+
If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
|
|
762
836
|
"""
|
|
763
|
-
# ---- Basic checks ----
|
|
764
|
-
if not isinstance(im,torch.Tensor):
|
|
765
|
-
im=torch.Tensor(im).to(device=self.device, dtype=self.dtype)
|
|
766
|
-
if not isinstance(ww,torch.Tensor):
|
|
767
|
-
ww=torch.Tensor(ww).to(device=self.device, dtype=self.dtype)
|
|
768
|
-
|
|
769
|
-
assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"
|
|
770
|
-
|
|
771
|
-
assert ww.shape[2]==self.KERNELSZ*self.KERNELSZ, f"`ww` must be (C_i, C_o, KERNELSZ*KERNELSZ), got {tuple(ww.shape)}"
|
|
772
|
-
|
|
773
|
-
B, C_i, Npix = im.shape
|
|
774
|
-
device = im.device
|
|
775
|
-
dtype = im.dtype
|
|
776
|
-
|
|
777
|
-
if cell_ids is not None:
|
|
778
|
-
|
|
779
|
-
idx_nn,w_idx,w_w = self.make_idx_weights_from_cell_ids(cell_ids,nside,device=device)
|
|
780
|
-
else:
|
|
781
|
-
idx_nn = self.idx_nn # (Npix, P)
|
|
782
|
-
w_idx = self.w_idx # (Npix, P) or (Npix, S, P)
|
|
783
|
-
w_w = self.w_w # (Npix, P) or (Npix, S, P)
|
|
784
|
-
|
|
785
|
-
# Neighbor count P inferred from idx_nn
|
|
786
|
-
assert idx_nn.ndim == 2 and idx_nn.size(0) == Npix, \
|
|
787
|
-
f"`idx_nn` must be (Npix, P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
|
|
788
|
-
P = idx_nn.size(1)
|
|
789
|
-
|
|
790
|
-
# ---- 1) Gather neighbor values from im along the Npix dimension -> (B, C_i, Npix, P)
|
|
791
|
-
# im: (B,C_i,Npix) -> (B,C_i,Npix,1); idx: (1,1,Npix,P) broadcast over (B,C_i)
|
|
792
|
-
rim = torch.take_along_dim(
|
|
793
|
-
im.unsqueeze(-1),
|
|
794
|
-
idx_nn.unsqueeze(0).unsqueeze(0),
|
|
795
|
-
dim=2
|
|
796
|
-
) # (B, C_i, Npix, P)
|
|
797
|
-
|
|
798
|
-
# ---- 2) Normalize w_idx / w_w to include a sector dim S ----
|
|
799
|
-
# Target layout: (Npix, S, P)
|
|
800
|
-
if w_idx.ndim == 2:
|
|
801
|
-
# (Npix, P) -> add sector dim S=1
|
|
802
|
-
assert w_idx.size(0) == Npix and w_idx.size(1) == P
|
|
803
|
-
w_idx_eff = w_idx.unsqueeze(1) # (Npix, 1, P)
|
|
804
|
-
w_w_eff = w_w.unsqueeze(1) # (Npix, 1, P)
|
|
805
|
-
S = 1
|
|
806
|
-
elif w_idx.ndim == 3:
|
|
807
|
-
# (Npix, S, P)
|
|
808
|
-
Npix_, S, P_ = w_idx.shape
|
|
809
|
-
assert Npix_ == Npix and P_ == P, \
|
|
810
|
-
f"`w_idx` must be (Npix,S,P) with Npix={Npix}, P={P}, got {tuple(w_idx.shape)}"
|
|
811
|
-
assert w_w.shape == w_idx.shape, "`w_w` must match `w_idx` shape"
|
|
812
|
-
w_idx_eff = w_idx
|
|
813
|
-
w_w_eff = w_w
|
|
814
|
-
else:
|
|
815
|
-
raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")
|
|
816
|
-
|
|
817
|
-
# ---- 3) Normalize ww to (B, C_i, C_o, M, S) for uniform gather ----
|
|
818
|
-
if ww.ndim == 3:
|
|
819
|
-
# (C_i, C_o, M) -> (B, C_i, C_o, M, S)
|
|
820
|
-
C_i_w, C_o, M = ww.shape
|
|
821
|
-
assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
|
|
822
|
-
ww_eff = ww.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, -1, S)
|
|
823
|
-
|
|
824
|
-
elif ww.ndim == 4:
|
|
825
|
-
# Could be (C_i, C_o, M, S) or (B, C_i, C_o, M)
|
|
826
|
-
if ww.shape[0] == C_i and ww.shape[1] != C_i:
|
|
827
|
-
# (C_i, C_o, M, S) -> (B, C_i, C_o, M, S)
|
|
828
|
-
C_i_w, C_o, M, S_w = ww.shape
|
|
829
|
-
assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
|
|
830
|
-
assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
|
|
831
|
-
ww_eff = ww.unsqueeze(0).expand(B, -1, -1, -1, -1)
|
|
832
|
-
elif ww.shape[0] == B:
|
|
833
|
-
# (B, C_i, C_o, M) -> (B, C_i, C_o, M, S)
|
|
834
|
-
_, C_i_w, C_o, M = ww.shape
|
|
835
|
-
assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
|
|
836
|
-
ww_eff = ww.unsqueeze(-1).expand(-1, -1, -1, -1, S)
|
|
837
|
-
else:
|
|
838
|
-
raise ValueError(
|
|
839
|
-
f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)"
|
|
840
|
-
)
|
|
841
|
-
|
|
842
|
-
elif ww.ndim == 5:
|
|
843
|
-
# (B, C_i, C_o, M, S)
|
|
844
|
-
assert ww.shape[0] == B and ww.shape[1] == C_i, "ww batch/C_i mismatch"
|
|
845
|
-
_, _, _, M, S_w = ww.shape
|
|
846
|
-
assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
|
|
847
|
-
ww_eff = ww
|
|
848
|
-
else:
|
|
849
|
-
raise ValueError(f"Unsupported ww shape {tuple(ww.shape)}")
|
|
850
|
-
|
|
851
|
-
# ---- 4) Gather along M using w_idx_eff -> (B, C_i, C_o, Npix, S, P)
|
|
852
|
-
idx_exp = w_idx_eff.unsqueeze(0).unsqueeze(0).unsqueeze(0) # (1,1,1,Npix,S,P)
|
|
853
|
-
rw = torch.take_along_dim(
|
|
854
|
-
ww_eff.unsqueeze(-1), # (B, C_i, C_o, M, S, 1)
|
|
855
|
-
idx_exp, # (1,1,1,Npix,S,P) -> broadcast
|
|
856
|
-
dim=3 # gather along M
|
|
857
|
-
) # -> (B, C_i, C_o, Npix, S, P)
|
|
858
|
-
|
|
859
|
-
# ---- 5) Apply extra neighbor weights ----
|
|
860
|
-
rw = rw * w_w_eff.unsqueeze(0).unsqueeze(0).unsqueeze(0) # (B, C_i, C_o, Npix, S, P)
|
|
861
|
-
|
|
862
|
-
# ---- 6) Combine neighbor values and weights ----
|
|
863
|
-
# rim: (B, C_i, Npix, P) -> expand to (B, C_i, 1, Npix, 1, P)
|
|
864
|
-
rim_exp = rim[:, :, None, :, None, :]
|
|
865
|
-
# sum over neighbors (P), then over sectors (S), then over input channels (C_i)
|
|
866
|
-
out_ci = (rim_exp * rw).sum(dim=-1) # (B, C_i, C_o, Npix, S)
|
|
867
|
-
out_ci = out_ci.sum(dim=-1) # (B, C_i, C_o, Npix)
|
|
868
|
-
out = out_ci.sum(dim=1) # (B, C_o, Npix)
|
|
869
|
-
|
|
870
|
-
return out
|
|
871
|
-
|
|
872
|
-
def Down(self, im, cell_ids=None,nside=None):
|
|
873
837
|
if self.f is None:
|
|
874
838
|
if self.dtype==torch.float64:
|
|
875
839
|
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
|
|
876
840
|
else:
|
|
877
841
|
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
|
|
878
|
-
|
|
879
|
-
if cell_ids is None:
|
|
880
|
-
dim,_ = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside)
|
|
881
|
-
return dim
|
|
882
|
-
else:
|
|
883
|
-
if nside is None:
|
|
884
|
-
nside=self.nside
|
|
885
|
-
if len(cell_ids.shape)==1:
|
|
886
|
-
return self.f.ud_grade_2(im,cell_ids=cell_ids,nside=nside)
|
|
887
|
-
else:
|
|
888
|
-
assert im.shape[0] == cell_ids.shape[0], \
|
|
889
|
-
f"cell_ids and data should have the same batch size (first column), got data={im.shape},cell_ids={cell_ids.shape}"
|
|
890
|
-
|
|
891
|
-
result,result_cell_ids = [],[]
|
|
892
|
-
|
|
893
|
-
for k in range(im.shape[0]):
|
|
894
|
-
r,c = self.f.ud_grade_2(im[k],cell_ids=cell_ids[k],nside=nside)
|
|
895
|
-
result.append(r)
|
|
896
|
-
result_cell_ids.append(c)
|
|
897
|
-
|
|
898
|
-
result = torch.stack(result, dim=0) # (B,...,Npix)
|
|
899
|
-
result_cell_ids = torch.stack(result_cell_ids, dim=0) # (B,Npix)
|
|
900
|
-
return result,result_cell_ids
|
|
901
842
|
|
|
902
|
-
|
|
843
|
+
if cell_ids is None:
|
|
844
|
+
dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside,max_poll=False)
|
|
845
|
+
return dim,cdim
|
|
846
|
+
|
|
847
|
+
if nside is None:
|
|
848
|
+
nside = self.nside
|
|
849
|
+
|
|
850
|
+
# var-length mode: list/tuple of ids, one per sample
|
|
851
|
+
if self._is_varlength_batch(cell_ids):
|
|
852
|
+
outs, outs_ids = [], []
|
|
853
|
+
B = len(cell_ids)
|
|
854
|
+
for b in range(B):
|
|
855
|
+
cid_b = self._to_numpy_1d(cell_ids[b])
|
|
856
|
+
# extraire le bon échantillon d'`im`
|
|
857
|
+
if torch.is_tensor(im):
|
|
858
|
+
xb = im[b:b+1] # (1, C, N_b)
|
|
859
|
+
yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
|
|
860
|
+
outs.append(yb.squeeze(0)) # (C, N_b')
|
|
861
|
+
else:
|
|
862
|
+
# si im est déjà une liste de (C, N_b)
|
|
863
|
+
xb = im[b]
|
|
864
|
+
yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
|
|
865
|
+
outs.append(yb.squeeze(0))
|
|
866
|
+
outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
|
|
867
|
+
return outs, outs_ids
|
|
868
|
+
|
|
869
|
+
# grille commune (un seul vecteur d'ids)
|
|
870
|
+
cid = self._to_numpy_1d(cell_ids)
|
|
871
|
+
return self.f.ud_grade_2(im, cell_ids=cid, nside=nside,max_poll=False)
|
|
872
|
+
|
|
873
|
+
def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
|
|
874
|
+
"""
|
|
875
|
+
If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
|
|
876
|
+
If they are lists (var-length per sample) -> return list[Tensor].
|
|
877
|
+
"""
|
|
903
878
|
if self.f is None:
|
|
904
879
|
if self.dtype==torch.float64:
|
|
905
880
|
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
|
|
906
881
|
else:
|
|
907
882
|
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
|
|
908
|
-
|
|
883
|
+
|
|
909
884
|
if cell_ids is None:
|
|
910
885
|
dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
|
|
911
886
|
return dim
|
|
912
|
-
else:
|
|
913
|
-
if nside is None:
|
|
914
|
-
nside=self.nside
|
|
915
|
-
if nside is None:
|
|
916
|
-
nside=self.nside
|
|
917
|
-
if len(cell_ids.shape)==1:
|
|
918
|
-
return self.f.up_grade(im,nside*2,cell_ids=cell_ids,nside=nside,o_cell_ids=o_cell_ids)
|
|
919
|
-
else:
|
|
920
|
-
assert im.shape[0] == cell_ids.shape[0], \
|
|
921
|
-
f"cell_ids and data should have the same batch size (first column), got data={im.shape},cell_ids={cell_ids.shape}"
|
|
922
|
-
|
|
923
|
-
assert im.shape[0] == o_cell_ids.shape[0], \
|
|
924
|
-
f"cell_ids and data should have the same batch size (first column), got data={im.shape},o_cell_ids={o_cell_ids.shape}"
|
|
925
|
-
|
|
926
|
-
result = []
|
|
927
|
-
|
|
928
|
-
for k in range(im.shape[0]):
|
|
929
|
-
r= self.f.up_grade(im[k],nside*2,cell_ids=cell_ids[k],nside=nside,o_cell_ids=o_cell_ids[k])
|
|
930
|
-
result.append(r)
|
|
931
|
-
|
|
932
|
-
result = torch.stack(result, dim=0) # (B,...,Npix)
|
|
933
|
-
return result
|
|
934
887
|
|
|
888
|
+
if nside is None:
|
|
889
|
+
nside = self.nside
|
|
890
|
+
|
|
891
|
+
# var-length: listes parallèles
|
|
892
|
+
if self._is_varlength_batch(cell_ids):
|
|
893
|
+
assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
|
|
894
|
+
"In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
|
|
895
|
+
outs = []
|
|
896
|
+
B = len(cell_ids)
|
|
897
|
+
for b in range(B):
|
|
898
|
+
cid_b = self._to_numpy_1d(cell_ids[b]) # coarse ids
|
|
899
|
+
ocid_b = self._to_numpy_1d(o_cell_ids[b]) # fine ids
|
|
900
|
+
if torch.is_tensor(im):
|
|
901
|
+
xb = im[b:b+1] # (1, C, N_b_coarse)
|
|
902
|
+
yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
|
|
903
|
+
o_cell_ids=ocid_b, force_init_index=True)
|
|
904
|
+
outs.append(yb.squeeze(0)) # (C, N_b_fine)
|
|
905
|
+
else:
|
|
906
|
+
xb = im[b] # (C, N_b_coarse)
|
|
907
|
+
yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
|
|
908
|
+
o_cell_ids=ocid_b, force_init_index=True)
|
|
909
|
+
outs.append(yb.squeeze(0))
|
|
910
|
+
return outs
|
|
911
|
+
|
|
912
|
+
# grille commune
|
|
913
|
+
cid = self._to_numpy_1d(cell_ids)
|
|
914
|
+
ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
|
|
915
|
+
return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
|
|
916
|
+
o_cell_ids=ocid, force_init_index=True)
|
|
917
|
+
|
|
935
918
|
def to_tensor(self,x):
|
|
936
919
|
if self.f is None:
|
|
937
920
|
if self.dtype==torch.float64:
|
foscat/HealBili.py
CHANGED
|
@@ -37,6 +37,7 @@ All angles must be in **radians**. theta is colatitude (0 at north pole), phi is
|
|
|
37
37
|
from __future__ import annotations
|
|
38
38
|
|
|
39
39
|
from typing import Tuple
|
|
40
|
+
import healpy as hp
|
|
40
41
|
import numpy as np
|
|
41
42
|
|
|
42
43
|
try:
|
|
@@ -79,16 +80,16 @@ class HealBili:
|
|
|
79
80
|
# Public API
|
|
80
81
|
# -----------------------------
|
|
81
82
|
def compute_weights(
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
83
|
+
self,
|
|
84
|
+
level,
|
|
85
|
+
cell_ids: np.ndarray,
|
|
85
86
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
86
87
|
"""Compute bilinear weights/indices for target HEALPix angles.
|
|
87
88
|
|
|
88
89
|
Parameters
|
|
89
90
|
----------
|
|
90
|
-
|
|
91
|
-
Target **
|
|
91
|
+
cell_ids : np.ndarray, shape (N,)
|
|
92
|
+
Target **cell_ids** .
|
|
92
93
|
|
|
93
94
|
Returns
|
|
94
95
|
-------
|
|
@@ -98,14 +99,17 @@ class HealBili:
|
|
|
98
99
|
Bilinear weights aligned with `I`. Weights are set to 0.0 for invalid corners and normalized to sum to 1
|
|
99
100
|
when at least one corner is valid.
|
|
100
101
|
"""
|
|
102
|
+
#compute the coordinate of the selected cell_ids
|
|
103
|
+
heal_theta, heal_phi = hp.pix2ang(2**level,cell_ids,nest=True)
|
|
104
|
+
|
|
101
105
|
ht = np.asarray(heal_theta, dtype=float).ravel()
|
|
102
|
-
|
|
103
|
-
if ht.shape !=
|
|
106
|
+
hpt = np.asarray(heal_phi, dtype=float).ravel()
|
|
107
|
+
if ht.shape != hpt.shape:
|
|
104
108
|
raise ValueError("heal_theta and heal_phi must have the same 1D shape (N,)")
|
|
105
109
|
N = ht.size
|
|
106
110
|
|
|
107
111
|
# Target unit vectors
|
|
108
|
-
Vtgt = self._sph_to_vec(ht,
|
|
112
|
+
Vtgt = self._sph_to_vec(ht, hpt) # (N,3)
|
|
109
113
|
|
|
110
114
|
# 1) Choose a seed node for each target (nearest source grid node on the sphere)
|
|
111
115
|
seed_flat = self._nearest_source_indices(Vtgt)
|