foscat 2025.9.1__py3-none-any.whl → 2025.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/HOrientedConvol.py CHANGED
@@ -20,6 +20,7 @@ class HOrientedConvol:
20
20
  no_cell_ids=False,
21
21
  ):
22
22
 
23
+
23
24
  if dtype=='float64':
24
25
  self.dtype=torch.float64
25
26
  else:
@@ -72,7 +73,7 @@ class HOrientedConvol:
72
73
  t,p = hp.pix2ang(nside,self.cell_ids[idx_nn],nest=True)
73
74
  else:
74
75
  t,p = hp.pix2ang(nside,idx_nn,nest=True)
75
-
76
+
76
77
  self.t=t[:,0]
77
78
  self.p=p[:,0]
78
79
  vec_orig=hp.ang2vec(t,p)
@@ -357,6 +358,47 @@ class HOrientedConvol:
357
358
 
358
359
  return csr_array((w, (indice_1_0, indice_1_1)), shape=(12*self.nside**2, 12*self.nside**2*NORIENT))
359
360
 
361
+ def make_idx_weights_from_cell_ids(self,
362
+ i_cell_ids,
363
+ polar=False,
364
+ gamma=1.0,
365
+ device='cuda',
366
+ allow_extrapolation=True):
367
+ """
368
+ Accept 1D (Npix,) or 2D (B, Npix) cell_ids and return
369
+ tensors batched sur la 1ère dim (B, ...).
370
+ """
371
+ # → cast numpy
372
+ if torch.is_tensor(i_cell_ids):
373
+ cid = i_cell_ids.detach().cpu().numpy()
374
+ else:
375
+ cid = np.asarray(i_cell_ids)
376
+
377
+ # --- 1D: pas de boucle, on calcule une fois, puis on ajoute l'axe batch
378
+ if cid.ndim == 1:
379
+ l_idx_nn, l_w_idx, l_w_w = self.make_idx_weights_from_one_cell_ids(
380
+ cid, polar=polar, gamma=gamma, device=device,
381
+ allow_extrapolation=allow_extrapolation
382
+ )
383
+ idx_nn = torch.as_tensor(l_idx_nn, device=device, dtype=torch.long)[None, ...] # (1, Npix, P)
384
+ w_idx = torch.as_tensor(l_w_idx, device=device, dtype=torch.long)[None, ...] # (1, Npix, S, P) ou (1, Npix, P)
385
+ w_w = torch.as_tensor(l_w_w, device=device, dtype=self.dtype)[None, ...] # (1, Npix, S, P) ou (1, Npix, P)
386
+ return idx_nn, w_idx, w_w
387
+
388
+ # --- 2D: boucle sur b, empilement en (B, ...)
389
+ elif cid.ndim == 2:
390
+ outs = [ self.make_idx_weights_from_one_cell_ids(
391
+ cid[k], polar=polar, gamma=gamma, device=device,
392
+ allow_extrapolation=allow_extrapolation)
393
+ for k in range(cid.shape[0]) ]
394
+ idx_nn = torch.as_tensor(np.stack([o[0] for o in outs], axis=0), device=device, dtype=torch.long)
395
+ w_idx = torch.as_tensor(np.stack([o[1] for o in outs], axis=0), device=device, dtype=torch.long)
396
+ w_w = torch.as_tensor(np.stack([o[2] for o in outs], axis=0), device=device, dtype=self.dtype)
397
+ return idx_nn, w_idx, w_w
398
+
399
+ else:
400
+ raise ValueError(f"Unsupported cell_ids ndim={cid.ndim}; expected 1 or 2.")
401
+ '''
360
402
  def make_idx_weights_from_cell_ids(self,i_cell_ids,
361
403
  polar=False,
362
404
  gamma=1.0,
@@ -387,7 +429,8 @@ class HOrientedConvol:
387
429
  w_w = torch.Tensor(np.stack(w_w,0)).to(device=device, dtype=self.dtype)
388
430
 
389
431
  return idx_nn,w_idx,w_w
390
-
432
+ '''
433
+
391
434
  def make_idx_weights_from_one_cell_ids(self,
392
435
  cell_ids,
393
436
  polar=False,
@@ -428,11 +471,34 @@ class HOrientedConvol:
428
471
  w_idx,w_w = self.bilinear_weights_NxN(xx*self.nside*gamma,
429
472
  yy*self.nside*gamma,
430
473
  allow_extrapolation=allow_extrapolation)
474
+ '''
475
+ # calib : [Npix, K]
476
+ calib = np.zeros((w_idx.shape[0], w_idx.shape[2]))
477
+ # Hypothèses :
478
+ # w_idx.shape == (Npix, M, K) et w_w.shape == (Npix, M, K)
479
+ Npix, M, K = w_idx.shape
480
+ nb_cols = K
481
+
482
+ # 1) Accumulation par "bincount" avec décalage de ligne
483
+ row_ids = np.arange(Npix, dtype=np.int64)[:, None, None] * nb_cols
484
+ flat_idx = (row_ids + w_idx).ravel() # indices dans [0, Npix*9)
485
+ weights = w_w.ravel().astype(np.float64) # ou dtype de ton choix
486
+
487
+ calib = np.bincount(flat_idx, weights, minlength=Npix*nb_cols)\
488
+ .reshape(Npix, nb_cols)
489
+
490
+ # 2) Réinjection dans norm_a selon w_idx
491
+ norm_a = calib[np.arange(Npix)[:, None, None], w_idx]
492
+
493
+ w_w /= norm_a
494
+ w_w = np.clip(w_w,0.0,1.0)
431
495
 
432
- del xx
433
- del yy
496
+ w_w[np.isnan(w_w)]=0.0
497
+ '''
498
+ #del xx
499
+ #del yy
434
500
 
435
- return idx_nn,w_idx,w_w
501
+ return idx_nn,w_idx,w_w,xx,yy
436
502
 
437
503
  def make_idx_weights(self,polar=False,gamma=1.0,device='cuda',allow_extrapolation=True,return_index=False):
438
504
 
@@ -547,27 +613,58 @@ class HOrientedConvol:
547
613
  idx = np.stack([i00, i10, i01, i11], axis=1).astype(np.int64)
548
614
 
549
615
  return idx, w
616
+
617
+ # --- Add inside class HOrientedConvol, just above Convol_torch ---
618
+ def _convol_single(self, im1: torch.Tensor, ww: torch.Tensor, cell_ids=None, nside=None):
619
+ """
620
+ Single-sample path. im1: (1, C_i, Npix_1). Returns (1, C_o, Npix_1).
621
+ """
622
+ if not isinstance(im1, torch.Tensor):
623
+ im1 = torch.as_tensor(im1, device=self.device, dtype=self.dtype)
624
+ if not isinstance(ww, torch.Tensor):
625
+ ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)
626
+ assert im1.ndim == 3 and im1.shape[0] == 1, f"expected (1, C_i, Npix), got {tuple(im1.shape)}"
627
+
628
+ # Reuse the existing Convol_torch core by faking B=1 shapes.
629
+ # We call the existing (batched) implementation with B=1.
630
+ return self.Convol_torch(im1, ww, cell_ids=cell_ids, nside=nside) # returns (1, C_o, Npix_1)
550
631
 
632
+ # --- Replace the first lines of Convol_torch with a dispatcher ---
551
633
  def Convol_torch(self, im, ww, cell_ids=None, nside=None):
552
634
  """
553
- Batched KERNELSZxKERNELSZ neighborhood aggregation in pure PyTorch (generalization of the 3x3 case).
635
+ Batched KERNELSZxKERNELSZ aggregation.
554
636
 
555
- Parameters
556
- ----------
557
- im : Tensor, shape (B, C_i, Npix)
558
- ww : Tensor, shapes supported:
559
- (C_i, C_o, M) | (C_i, C_o, M, S) | (B, C_i, C_o, M) | (B, C_i, C_o, M, S)
560
- cell_ids : ndarray or Tensor
561
- - None: use precomputed self.idx_nn / self.w_idx / self.w_w (shared for batch).
562
- - (Npix,): recompute once (shared for batch).
563
- - (B, Npix): recompute per-sample (different for each b).
564
-
565
- Returns
566
- -------
567
- out : Tensor, shape (B, C_o, Npix)
637
+ Accepts either:
638
+ - im: Tensor (B, C_i, Npix) with one shared or per-batch (B,Npix) cell_ids
639
+ - im: list/tuple of Tensors, each (C_i, Npix_b), with cell_ids a list of arrays
568
640
  """
569
641
  import torch
570
642
 
643
+ # (A) Variable-length per-sample path: im is a list/tuple OR cell_ids is a list/tuple
644
+ if isinstance(im, (list, tuple)) or isinstance(cell_ids, (list, tuple)):
645
+ # Normalize to lists
646
+ im_list = im if isinstance(im, (list, tuple)) else [im]
647
+ cid_list = cell_ids if isinstance(cell_ids, (list, tuple)) else [cell_ids] * len(im_list)
648
+ assert len(im_list) == len(cid_list), "im list and cell_ids list must have same length"
649
+
650
+ outs = []
651
+ for xb, cb in zip(im_list, cid_list):
652
+ # xb: (C_i, Npix_b) -> (1, C_i, Npix_b)
653
+ if not torch.is_tensor(xb):
654
+ xb = torch.as_tensor(xb, device=self.device, dtype=self.dtype)
655
+ if xb.dim() == 2:
656
+ xb = xb.unsqueeze(0)
657
+ elif xb.dim() != 3 or xb.shape[0] != 1:
658
+ raise ValueError(f"Each sample must be (C,N) or (1,C,N); got {tuple(xb.shape)}")
659
+
660
+ yb = self._convol_single(xb, ww, cell_ids=cb, nside=nside) # (1, C_o, Npix_b)
661
+ outs.append(yb.squeeze(0)) # -> (C_o, Npix_b)
662
+ return outs # List[Tensor], each (C_o, Npix_b)
663
+
664
+ # (B) Standard fixed-length batched path (your current implementation)
665
+ # ... keep your existing Convol_torch body from here unchanged ...
666
+ # (paste your current function body starting from the type casting and assertions)
667
+
571
668
  # ---- Basic checks / casting ----
572
669
  if not isinstance(im, torch.Tensor):
573
670
  im = torch.as_tensor(im, device=self.device, dtype=self.dtype)
@@ -585,55 +682,37 @@ class HOrientedConvol:
585
682
  # w_idx_eff : (B, Npix, S, P)
586
683
  # w_w_eff : (B, Npix, S, P)
587
684
  if cell_ids is not None:
588
- # to numpy for your make_idx_weights_from_cell_ids helper if needed
589
- if isinstance(cell_ids, torch.Tensor):
590
- cid = cell_ids.detach().to("cpu").numpy()
591
- else:
592
- cid = cell_ids
593
-
594
- if cid.ndim == 1:
595
- # single set of ids for the whole batch
596
- idx_nn, w_idx, w_w = self.make_idx_weights_from_cell_ids(cid, nside, device=device)
597
- assert idx_nn.ndim == 2, "idx_nn expected (Npix,P)"
598
- P = idx_nn.shape[1]
599
- if w_idx.ndim == 2:
600
- # (Npix,P) -> (B,Npix,1,P)
601
- S = 1
602
- w_idx_eff = w_idx[None, :, None, :].expand(B, -1, -1, -1)
603
- w_w_eff = w_w[None, :, None, :].expand(B, -1, -1, -1)
604
- elif w_idx.ndim == 3:
605
- # (Npix,S,P) -> (B,Npix,S,P)
606
- S = w_idx.shape[1]
607
- w_idx_eff = w_idx[None, ...].expand(B, -1, -1, -1)
608
- w_w_eff = w_w[None, ...].expand(B, -1, -1, -1)
685
+ # ---- Recompute (idx_nn, w_idx, w_w) depending on cell_ids shape ----
686
+ # Normaliser: accepter Tensor, ndarray ou list/tuple de 1 élément (cas var-length, B=1)
687
+ if isinstance(cell_ids, (list, tuple)):
688
+ # liste d'ids (souvent longueur 1 en var-length)
689
+ if len(cell_ids) == 1:
690
+ cid = np.asarray(cell_ids[0])[None, :] # -> (1, Npix)
609
691
  else:
610
- raise ValueError(f"Unsupported w_idx shape {tuple(w_idx.shape)}")
611
- idx_nn_eff = idx_nn[None, ...].expand(B, -1, -1) # (B,Npix,P)
612
-
613
- elif cid.ndim == 2:
614
- # per-sample ids
615
- assert cid.shape[0] == B and cid.shape[1] == Npix, \
616
- f"cell_ids must be (B,Npix) with B={B},Npix={Npix}, got {cid.shape}"
617
- S_ref = None
618
-
619
- idx_nn_eff, w_idx_eff, w_w_eff = self.make_idx_weights_from_cell_ids(cid,
620
- nside,
621
- device=device)
622
- P = idx_nn_eff.shape[-1]
623
- S = w_idx_eff.shape[-2]
624
-
692
+ # si jamais >1, on essaie d'empiler (doit avoir même Npix par élément)
693
+ cid = np.stack([np.asarray(c) for c in cell_ids], axis=0)
694
+ elif torch.is_tensor(cell_ids):
695
+ c = cell_ids.detach().cpu().numpy()
696
+ cid = c if c.ndim != 1 else c[None, :] # uniformiser en 2D quand B=1
625
697
  else:
626
- raise ValueError(f"Unsupported cell_ids shape {cid.shape}")
698
+ c = np.asarray(cell_ids)
699
+ cid = c if c.ndim != 1 else c[None, :]
627
700
 
628
- # ensure tensors on right device/dtype
701
+ # cid est maintenant (B, Npix)
702
+ idx_nn_eff, w_idx_eff, w_w_eff = self.make_idx_weights_from_cell_ids(
703
+ cid, nside, device=device
704
+ )
705
+ # shapes: (B, Npix, P), (B, Npix, S, P|P), (B, Npix, S, P|P)
706
+ P = idx_nn_eff.shape[-1]
707
+ S = w_idx_eff.shape[-2] if w_idx_eff.ndim == 4 else 1
708
+
709
+ # s’assurer des dtypes/devices
629
710
  idx_nn_eff = torch.as_tensor(idx_nn_eff, device=device, dtype=torch.long)
630
711
  w_idx_eff = torch.as_tensor(w_idx_eff, device=device, dtype=torch.long)
631
712
  w_w_eff = torch.as_tensor(w_w_eff, device=device, dtype=dtype)
632
-
633
713
  else:
634
714
  # Use precomputed (shared for batch)
635
715
  if self.w_idx is None:
636
-
637
716
  if self.cell_ids.ndim==1:
638
717
  l_cell=self.cell_ids[None,:]
639
718
  else:
@@ -724,214 +803,118 @@ class HOrientedConvol:
724
803
  out = out_ci.sum(dim=1) # sum over C_i -> (B, C_o, Npix)
725
804
 
726
805
  return out
727
-
728
- def Convol_torch_old(self, im, ww,cell_ids=None,nside=None):
729
- """
730
- Batched KERNELSZxKERNELSZ neighborhood aggregation in pure PyTorch (generalization of the 3x3 case).
731
806
 
732
- Parameters
733
- ----------
734
- im : Tensor, shape (B, C_i, Npix)
735
- Input features per pixel for a batch of B samples.
736
- ww : Tensor
737
- Base mixing weights, indexed along its 'M' dimension by self.w_idx.
738
- Supported shapes:
739
- (C_i, C_o, M)
740
- (C_i, C_o, M, S)
741
- (B, C_i, C_o, M)
742
- (B, C_i, C_o, M, S)
743
-
744
- cell_ids : ndarray
745
- If cell_ids is not None recompute the index and do not use the precomputed ones.
746
- Note : The computation is then much longer.
747
-
748
- Class members (already tensors; will be aligned to im.device/dtype):
749
- -------------------------------------------------------------------
750
- self.idx_nn : LongTensor, shape (Npix, P)
751
- For each center pixel, the P neighbor indices into the Npix axis of `im`.
752
- (P = K*K for a KxK neighborhood.)
753
- self.w_idx : LongTensor, shape (Npix, P) or (Npix, S, P)
754
- Indices along the 'M' dimension of ww, per (center[, sector], neighbor).
755
- self.w_w : Tensor, shape (Npix, P) or (Npix, S, P)
756
- Additional scalar weights per neighbor (same layout as w_idx).
757
-
758
- Returns
759
- -------
760
- out : Tensor, shape (B, C_o, Npix)
761
- Aggregated output per center pixel for each batch sample.
807
+ def _to_numpy_1d(self, ids):
808
+ """Return a 1D numpy array of int64 for a single set of cell ids."""
809
+ import numpy as np, torch
810
+ if isinstance(ids, np.ndarray):
811
+ return ids.reshape(-1).astype(np.int64, copy=False)
812
+ if torch.is_tensor(ids):
813
+ return ids.detach().cpu().to(torch.long).view(-1).numpy()
814
+ # python list/tuple of ints
815
+ return np.asarray(ids, dtype=np.int64).reshape(-1)
816
+
817
+ def _is_varlength_batch(self, ids):
818
+ """
819
+ True if ids is a list/tuple of per-sample id arrays (var-length batch).
820
+ False if ids is a single array/tensor of ids (shared for whole batch).
821
+ """
822
+ import numpy as np, torch
823
+ if isinstance(ids, (list, tuple)):
824
+ return True
825
+ if isinstance(ids, np.ndarray) and ids.ndim == 2:
826
+ # This would be a dense (B, Npix) matrix -> NOT var-length list
827
+ return False
828
+ if torch.is_tensor(ids) and ids.dim() == 2:
829
+ return False
830
+ return False
831
+
832
+ def Down(self, im, cell_ids=None, nside=None,max_poll=False):
833
+ """
834
+ If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
835
+ If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
762
836
  """
763
- # ---- Basic checks ----
764
- if not isinstance(im,torch.Tensor):
765
- im=torch.Tensor(im).to(device=self.device, dtype=self.dtype)
766
- if not isinstance(ww,torch.Tensor):
767
- ww=torch.Tensor(ww).to(device=self.device, dtype=self.dtype)
768
-
769
- assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"
770
-
771
- assert ww.shape[2]==self.KERNELSZ*self.KERNELSZ, f"`ww` must be (C_i, C_o, KERNELSZ*KERNELSZ), got {tuple(ww.shape)}"
772
-
773
- B, C_i, Npix = im.shape
774
- device = im.device
775
- dtype = im.dtype
776
-
777
- if cell_ids is not None:
778
-
779
- idx_nn,w_idx,w_w = self.make_idx_weights_from_cell_ids(cell_ids,nside,device=device)
780
- else:
781
- idx_nn = self.idx_nn # (Npix, P)
782
- w_idx = self.w_idx # (Npix, P) or (Npix, S, P)
783
- w_w = self.w_w # (Npix, P) or (Npix, S, P)
784
-
785
- # Neighbor count P inferred from idx_nn
786
- assert idx_nn.ndim == 2 and idx_nn.size(0) == Npix, \
787
- f"`idx_nn` must be (Npix, P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
788
- P = idx_nn.size(1)
789
-
790
- # ---- 1) Gather neighbor values from im along the Npix dimension -> (B, C_i, Npix, P)
791
- # im: (B,C_i,Npix) -> (B,C_i,Npix,1); idx: (1,1,Npix,P) broadcast over (B,C_i)
792
- rim = torch.take_along_dim(
793
- im.unsqueeze(-1),
794
- idx_nn.unsqueeze(0).unsqueeze(0),
795
- dim=2
796
- ) # (B, C_i, Npix, P)
797
-
798
- # ---- 2) Normalize w_idx / w_w to include a sector dim S ----
799
- # Target layout: (Npix, S, P)
800
- if w_idx.ndim == 2:
801
- # (Npix, P) -> add sector dim S=1
802
- assert w_idx.size(0) == Npix and w_idx.size(1) == P
803
- w_idx_eff = w_idx.unsqueeze(1) # (Npix, 1, P)
804
- w_w_eff = w_w.unsqueeze(1) # (Npix, 1, P)
805
- S = 1
806
- elif w_idx.ndim == 3:
807
- # (Npix, S, P)
808
- Npix_, S, P_ = w_idx.shape
809
- assert Npix_ == Npix and P_ == P, \
810
- f"`w_idx` must be (Npix,S,P) with Npix={Npix}, P={P}, got {tuple(w_idx.shape)}"
811
- assert w_w.shape == w_idx.shape, "`w_w` must match `w_idx` shape"
812
- w_idx_eff = w_idx
813
- w_w_eff = w_w
814
- else:
815
- raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")
816
-
817
- # ---- 3) Normalize ww to (B, C_i, C_o, M, S) for uniform gather ----
818
- if ww.ndim == 3:
819
- # (C_i, C_o, M) -> (B, C_i, C_o, M, S)
820
- C_i_w, C_o, M = ww.shape
821
- assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
822
- ww_eff = ww.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, -1, S)
823
-
824
- elif ww.ndim == 4:
825
- # Could be (C_i, C_o, M, S) or (B, C_i, C_o, M)
826
- if ww.shape[0] == C_i and ww.shape[1] != C_i:
827
- # (C_i, C_o, M, S) -> (B, C_i, C_o, M, S)
828
- C_i_w, C_o, M, S_w = ww.shape
829
- assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
830
- assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
831
- ww_eff = ww.unsqueeze(0).expand(B, -1, -1, -1, -1)
832
- elif ww.shape[0] == B:
833
- # (B, C_i, C_o, M) -> (B, C_i, C_o, M, S)
834
- _, C_i_w, C_o, M = ww.shape
835
- assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
836
- ww_eff = ww.unsqueeze(-1).expand(-1, -1, -1, -1, S)
837
- else:
838
- raise ValueError(
839
- f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)"
840
- )
841
-
842
- elif ww.ndim == 5:
843
- # (B, C_i, C_o, M, S)
844
- assert ww.shape[0] == B and ww.shape[1] == C_i, "ww batch/C_i mismatch"
845
- _, _, _, M, S_w = ww.shape
846
- assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
847
- ww_eff = ww
848
- else:
849
- raise ValueError(f"Unsupported ww shape {tuple(ww.shape)}")
850
-
851
- # ---- 4) Gather along M using w_idx_eff -> (B, C_i, C_o, Npix, S, P)
852
- idx_exp = w_idx_eff.unsqueeze(0).unsqueeze(0).unsqueeze(0) # (1,1,1,Npix,S,P)
853
- rw = torch.take_along_dim(
854
- ww_eff.unsqueeze(-1), # (B, C_i, C_o, M, S, 1)
855
- idx_exp, # (1,1,1,Npix,S,P) -> broadcast
856
- dim=3 # gather along M
857
- ) # -> (B, C_i, C_o, Npix, S, P)
858
-
859
- # ---- 5) Apply extra neighbor weights ----
860
- rw = rw * w_w_eff.unsqueeze(0).unsqueeze(0).unsqueeze(0) # (B, C_i, C_o, Npix, S, P)
861
-
862
- # ---- 6) Combine neighbor values and weights ----
863
- # rim: (B, C_i, Npix, P) -> expand to (B, C_i, 1, Npix, 1, P)
864
- rim_exp = rim[:, :, None, :, None, :]
865
- # sum over neighbors (P), then over sectors (S), then over input channels (C_i)
866
- out_ci = (rim_exp * rw).sum(dim=-1) # (B, C_i, C_o, Npix, S)
867
- out_ci = out_ci.sum(dim=-1) # (B, C_i, C_o, Npix)
868
- out = out_ci.sum(dim=1) # (B, C_o, Npix)
869
-
870
- return out
871
-
872
- def Down(self, im, cell_ids=None,nside=None):
873
837
  if self.f is None:
874
838
  if self.dtype==torch.float64:
875
839
  self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
876
840
  else:
877
841
  self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
878
-
879
- if cell_ids is None:
880
- dim,_ = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside)
881
- return dim
882
- else:
883
- if nside is None:
884
- nside=self.nside
885
- if len(cell_ids.shape)==1:
886
- return self.f.ud_grade_2(im,cell_ids=cell_ids,nside=nside)
887
- else:
888
- assert im.shape[0] == cell_ids.shape[0], \
889
- f"cell_ids and data should have the same batch size (first column), got data={im.shape},cell_ids={cell_ids.shape}"
890
-
891
- result,result_cell_ids = [],[]
892
-
893
- for k in range(im.shape[0]):
894
- r,c = self.f.ud_grade_2(im[k],cell_ids=cell_ids[k],nside=nside)
895
- result.append(r)
896
- result_cell_ids.append(c)
897
-
898
- result = torch.stack(result, dim=0) # (B,...,Npix)
899
- result_cell_ids = torch.stack(result_cell_ids, dim=0) # (B,Npix)
900
- return result,result_cell_ids
901
842
 
902
- def Up(self, im, cell_ids=None,nside=None,o_cell_ids=None):
843
+ if cell_ids is None:
844
+ dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside,max_poll=False)
845
+ return dim,cdim
846
+
847
+ if nside is None:
848
+ nside = self.nside
849
+
850
+ # var-length mode: list/tuple of ids, one per sample
851
+ if self._is_varlength_batch(cell_ids):
852
+ outs, outs_ids = [], []
853
+ B = len(cell_ids)
854
+ for b in range(B):
855
+ cid_b = self._to_numpy_1d(cell_ids[b])
856
+ # extraire le bon échantillon d'`im`
857
+ if torch.is_tensor(im):
858
+ xb = im[b:b+1] # (1, C, N_b)
859
+ yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
860
+ outs.append(yb.squeeze(0)) # (C, N_b')
861
+ else:
862
+ # si im est déjà une liste de (C, N_b)
863
+ xb = im[b]
864
+ yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
865
+ outs.append(yb.squeeze(0))
866
+ outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
867
+ return outs, outs_ids
868
+
869
+ # grille commune (un seul vecteur d'ids)
870
+ cid = self._to_numpy_1d(cell_ids)
871
+ return self.f.ud_grade_2(im, cell_ids=cid, nside=nside,max_poll=False)
872
+
873
+ def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
874
+ """
875
+ If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
876
+ If they are lists (var-length per sample) -> return list[Tensor].
877
+ """
903
878
  if self.f is None:
904
879
  if self.dtype==torch.float64:
905
880
  self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
906
881
  else:
907
882
  self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
908
-
883
+
909
884
  if cell_ids is None:
910
885
  dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
911
886
  return dim
912
- else:
913
- if nside is None:
914
- nside=self.nside
915
- if nside is None:
916
- nside=self.nside
917
- if len(cell_ids.shape)==1:
918
- return self.f.up_grade(im,nside*2,cell_ids=cell_ids,nside=nside,o_cell_ids=o_cell_ids)
919
- else:
920
- assert im.shape[0] == cell_ids.shape[0], \
921
- f"cell_ids and data should have the same batch size (first column), got data={im.shape},cell_ids={cell_ids.shape}"
922
-
923
- assert im.shape[0] == o_cell_ids.shape[0], \
924
- f"cell_ids and data should have the same batch size (first column), got data={im.shape},o_cell_ids={o_cell_ids.shape}"
925
-
926
- result = []
927
-
928
- for k in range(im.shape[0]):
929
- r= self.f.up_grade(im[k],nside*2,cell_ids=cell_ids[k],nside=nside,o_cell_ids=o_cell_ids[k])
930
- result.append(r)
931
-
932
- result = torch.stack(result, dim=0) # (B,...,Npix)
933
- return result
934
887
 
888
+ if nside is None:
889
+ nside = self.nside
890
+
891
+ # var-length: listes parallèles
892
+ if self._is_varlength_batch(cell_ids):
893
+ assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
894
+ "In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
895
+ outs = []
896
+ B = len(cell_ids)
897
+ for b in range(B):
898
+ cid_b = self._to_numpy_1d(cell_ids[b]) # coarse ids
899
+ ocid_b = self._to_numpy_1d(o_cell_ids[b]) # fine ids
900
+ if torch.is_tensor(im):
901
+ xb = im[b:b+1] # (1, C, N_b_coarse)
902
+ yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
903
+ o_cell_ids=ocid_b, force_init_index=True)
904
+ outs.append(yb.squeeze(0)) # (C, N_b_fine)
905
+ else:
906
+ xb = im[b] # (C, N_b_coarse)
907
+ yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
908
+ o_cell_ids=ocid_b, force_init_index=True)
909
+ outs.append(yb.squeeze(0))
910
+ return outs
911
+
912
+ # grille commune
913
+ cid = self._to_numpy_1d(cell_ids)
914
+ ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
915
+ return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
916
+ o_cell_ids=ocid, force_init_index=True)
917
+
935
918
  def to_tensor(self,x):
936
919
  if self.f is None:
937
920
  if self.dtype==torch.float64:
foscat/HealBili.py CHANGED
@@ -37,6 +37,7 @@ All angles must be in **radians**. theta is colatitude (0 at north pole), phi is
37
37
  from __future__ import annotations
38
38
 
39
39
  from typing import Tuple
40
+ import healpy as hp
40
41
  import numpy as np
41
42
 
42
43
  try:
@@ -79,16 +80,16 @@ class HealBili:
79
80
  # Public API
80
81
  # -----------------------------
81
82
  def compute_weights(
82
- self,
83
- heal_theta: np.ndarray,
84
- heal_phi: np.ndarray,
83
+ self,
84
+ level,
85
+ cell_ids: np.ndarray,
85
86
  ) -> Tuple[np.ndarray, np.ndarray]:
86
87
  """Compute bilinear weights/indices for target HEALPix angles.
87
88
 
88
89
  Parameters
89
90
  ----------
90
- heal_theta, heal_phi : np.ndarray, shape (N,)
91
- Target **colatitude** and **longitude** in radians.
91
+ cell_ids : np.ndarray, shape (N,)
92
+ Target **cell_ids** .
92
93
 
93
94
  Returns
94
95
  -------
@@ -98,14 +99,17 @@ class HealBili:
98
99
  Bilinear weights aligned with `I`. Weights are set to 0.0 for invalid corners and normalized to sum to 1
99
100
  when at least one corner is valid.
100
101
  """
102
+ #compute the coordinate of the selected cell_ids
103
+ heal_theta, heal_phi = hp.pix2ang(2**level,cell_ids,nest=True)
104
+
101
105
  ht = np.asarray(heal_theta, dtype=float).ravel()
102
- hp = np.asarray(heal_phi, dtype=float).ravel()
103
- if ht.shape != hp.shape:
106
+ hpt = np.asarray(heal_phi, dtype=float).ravel()
107
+ if ht.shape != hpt.shape:
104
108
  raise ValueError("heal_theta and heal_phi must have the same 1D shape (N,)")
105
109
  N = ht.size
106
110
 
107
111
  # Target unit vectors
108
- Vtgt = self._sph_to_vec(ht, hp) # (N,3)
112
+ Vtgt = self._sph_to_vec(ht, hpt) # (N,3)
109
113
 
110
114
  # 1) Choose a seed node for each target (nearest source grid node on the sphere)
111
115
  seed_flat = self._nearest_source_indices(Vtgt)
foscat/Plot.py CHANGED
@@ -7,10 +7,10 @@ def lgnomproject(
7
7
  data, # array-like (N,), values per cell id
8
8
  nside: int,
9
9
  rot=None, # (lon0_deg, lat0_deg, psi_deg). If None: auto-center from cell_ids (pix centers)
10
- xsize: int = 800,
11
- ysize: int = 800,
10
+ xsize: int = 400,
11
+ ysize: int = 400,
12
12
  reso: float = None, # deg/pixel on tangent plane; if None, use fov_deg
13
- fov_deg=10.0, # full FoV deg (scalar or (fx,fy))
13
+ fov_deg=None, # full FoV deg (scalar or (fx,fy))
14
14
  nest: bool = True, # True if your cell_ids are NESTED (and ang2pix to be done in NEST)
15
15
  reduce: str = "mean", # 'mean'|'median'|'sum'|'first' when duplicates in cell_ids
16
16
  mask_outside: bool = True,
@@ -118,6 +118,9 @@ def lgnomproject(
118
118
  half_x = 0.5 * xsize * dx
119
119
  half_y = 0.5 * ysize * dy
120
120
  else:
121
+ if fov_deg is None:
122
+ fov_deg=np.rad2deg(np.sqrt(cell_ids.shape[0])/nside)*1.4
123
+
121
124
  if np.isscalar(fov_deg):
122
125
  fx, fy = float(fov_deg), float(fov_deg)
123
126
  else:
@@ -176,7 +179,7 @@ def lgnomproject(
176
179
  # Axes in approx. "gnomonic degrees" (atan of plane coords)
177
180
  x_deg = np.degrees(np.arctan(xs))
178
181
  y_deg = np.degrees(np.arctan(ys))
179
-
182
+
180
183
  longitude_min=x_deg[0]/np.cos(np.deg2rad(lat0_deg))+lon0_deg
181
184
  longitude_max=x_deg[-1]/np.cos(np.deg2rad(lat0_deg))+lon0_deg
182
185
 
@@ -199,7 +202,7 @@ def lgnomproject(
199
202
  cmap=cmap,
200
203
  vmin=vmin, vmax=vmax,
201
204
  interpolation="nearest",
202
- aspect="equal"
205
+ aspect="auto"
203
206
  )
204
207
  if not notext:
205
208
  ax.set_xlabel("Longitude (deg)")
@@ -216,7 +219,7 @@ def lgnomproject(
216
219
  cb = fig.colorbar(im, ax=ax)
217
220
  cb.set_label("value")
218
221
  else:
219
- plt.colorbar(im, ax=ax, orientation="vertical", label="value")
222
+ plt.colorbar(im, ax=ax, orientation="horizontal", label="value")
220
223
 
221
224
  plt.tight_layout()
222
225
  if hold: