foscat 2025.8.4__py3-none-any.whl → 2025.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/HOrientedConvol.py CHANGED
@@ -3,16 +3,37 @@ import matplotlib.pyplot as plt
  import healpy as hp
  from scipy.sparse import csr_array
  import torch
+ import foscat.scat_cov as sc
  from scipy.spatial import cKDTree

  class HOrientedConvol:
- def __init__(self,nside,KERNELSZ,cell_ids=None,nest=True):
+ def __init__(self,
+ nside,
+ KERNELSZ,
+ cell_ids=None,
+ nest=True,
+ device='cuda',
+ dtype='float64',
+ polar=False,
+ gamma=1.0,
+ allow_extrapolation=True,
+ no_cell_ids=False,
+ ):
+
+
+ if dtype=='float64':
+ self.dtype=torch.float64
+ else:
+ self.dtype=torch.float32

  if KERNELSZ % 2 == 0:
  raise ValueError(f"N must be odd so that coordinates are integers from -K..K; got N={KERNELSZ}.")

  self.local_test=False
-
+
+ if no_cell_ids==True:
+ cell_ids=np.arange(10)
+
  if cell_ids is None:
  self.cell_ids=np.arange(12*nside**2)

@@ -28,37 +49,84 @@ class HOrientedConvol:
  self.cell_ids=cell_ids

  self.local_test=True
-
- idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
- KERNELSZ*KERNELSZ,
- nside,
- nest=nest,
- )

+ if self.cell_ids.ndim==1:
+ idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
+ KERNELSZ*KERNELSZ,
+ nside,
+ nest=nest,
+ )
+ else:
+ idx_nn = []
+ for k in range(self.cell_ids.shape[0]):
+ idx_nn.append(self.knn_healpix_ckdtree(self.cell_ids[k],
+ KERNELSZ*KERNELSZ,
+ nside,
+ nest=nest,
+ ))
+ idx_nn=np.stack(idx_nn,0)
+
+ if self.cell_ids.ndim==1:
+ mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids,nest=nest)

- mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids,nest=nest)
+ if self.local_test:
+ t,p = hp.pix2ang(nside,self.cell_ids[idx_nn],nest=True)
+ else:
+ t,p = hp.pix2ang(nside,idx_nn,nest=True)
+
+ self.t=t[:,0]
+ self.p=p[:,0]
+ vec_orig=hp.ang2vec(t,p)

- if self.local_test:
- t,p = hp.pix2ang(nside,self.cell_ids[idx_nn],nest=True)
+ self.vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
+
+ '''
+ if self.local_test:
+ idx_nn=self.remap_by_first_column(idx_nn)
+ '''
+
+ del mat_pt
+ del vec_orig
  else:
- t,p = hp.pix2ang(nside,idx_nn,nest=True)

- vec_orig=hp.ang2vec(t,p)
+ t,p,vec_rot = [],[],[]
+
+ for k in range(self.cell_ids.shape[0]):
+ mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids[k],nest=nest)
+
+ lt,lp = hp.pix2ang(nside,self.cell_ids[k,idx_nn[k]],nest=True)
+
+ vec_orig=hp.ang2vec(lt,lp)
+
+ l_vec_rot=np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
+ vec_rot.append(l_vec_rot)
+
+ del vec_orig
+ del mat_pt
+
+ t.append(lt[:,0])
+ p.append(lp[:,0])

- self.vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
+
+ self.t=np.stack(t,0)
+ self.p=np.stack(p,0)
+ self.vec_rot=np.stack(vec_rot,0)

- '''
- if self.local_test:
- idx_nn=self.remap_by_first_column(idx_nn)
- '''
+ del t
+ del p
+ del vec_rot
+
+ self.polar=polar
+ self.gamma=gamma
+ self.device=device
+ self.allow_extrapolation=allow_extrapolation
+ self.w_idx=None

- del mat_pt
- del vec_orig
- self.t=t[:,0]
- self.p=p[:,0]
  self.idx_nn=idx_nn
  self.nside=nside
  self.KERNELSZ=KERNELSZ
+ self.nest=nest
+ self.f=None

  def remap_by_first_column(self,idx: np.ndarray) -> np.ndarray:
  """
@@ -290,25 +358,160 @@ class HOrientedConvol:

  return csr_array((w, (indice_1_0, indice_1_1)), shape=(12*self.nside**2, 12*self.nside**2*NORIENT))

-
- def make_idx_weights(self,polar=False,gamma=1.0,device='cuda',allow_extrapolation=True):
+ def make_idx_weights_from_cell_ids(self,
+ i_cell_ids,
+ polar=False,
+ gamma=1.0,
+ device='cuda',
+ allow_extrapolation=True):
+ """
+ Accept 1D (Npix,) or 2D (B, Npix) cell_ids and return
+ tensors batched over the first dim (B, ...).
+ """
+ # → cast to numpy
+ if torch.is_tensor(i_cell_ids):
+ cid = i_cell_ids.detach().cpu().numpy()
+ else:
+ cid = np.asarray(i_cell_ids)
+
+ # --- 1D: no loop, compute once, then add the batch axis
+ if cid.ndim == 1:
+ l_idx_nn, l_w_idx, l_w_w = self.make_idx_weights_from_one_cell_ids(
+ cid, polar=polar, gamma=gamma, device=device,
+ allow_extrapolation=allow_extrapolation
+ )
+ idx_nn = torch.as_tensor(l_idx_nn, device=device, dtype=torch.long)[None, ...] # (1, Npix, P)
+ w_idx = torch.as_tensor(l_w_idx, device=device, dtype=torch.long)[None, ...] # (1, Npix, S, P) or (1, Npix, P)
+ w_w = torch.as_tensor(l_w_w, device=device, dtype=self.dtype)[None, ...] # (1, Npix, S, P) or (1, Npix, P)
+ return idx_nn, w_idx, w_w
+
+ # --- 2D: loop over b, stack into (B, ...)
+ elif cid.ndim == 2:
+ outs = [ self.make_idx_weights_from_one_cell_ids(
+ cid[k], polar=polar, gamma=gamma, device=device,
+ allow_extrapolation=allow_extrapolation)
+ for k in range(cid.shape[0]) ]
+ idx_nn = torch.as_tensor(np.stack([o[0] for o in outs], axis=0), device=device, dtype=torch.long)
+ w_idx = torch.as_tensor(np.stack([o[1] for o in outs], axis=0), device=device, dtype=torch.long)
+ w_w = torch.as_tensor(np.stack([o[2] for o in outs], axis=0), device=device, dtype=self.dtype)
+ return idx_nn, w_idx, w_w
+
+ else:
+ raise ValueError(f"Unsupported cell_ids ndim={cid.ndim}; expected 1 or 2.")
+ '''
+ def make_idx_weights_from_cell_ids(self,i_cell_ids,
+ polar=False,
+ gamma=1.0,
+ device='cuda',
+ allow_extrapolation=True):
+ if len(i_cell_ids.shape)<2:
+ cell_ids=i_cell_ids
+ n_cids=1
+ else:
+ cell_ids=i_cell_ids[0]
+ n_cids=i_cell_ids.shape[0]
+
+ idx_nn,w_idx,w_w = [],[],[]
+
+ for k in range(n_cids):
+ cell_ids=i_cell_ids[k]
+ l_idx_nn,l_w_idx,l_w_w = self.make_idx_weights_from_one_cell_ids(cell_ids,
+ polar=polar,
+ gamma=gamma,
+ device=device,
+ allow_extrapolation=allow_extrapolation)
+ idx_nn.append(l_idx_nn)
+ w_idx.append(l_w_idx)
+ w_w.append(l_w_w)
+
+ idx_nn = torch.Tensor(np.stack(idx_nn,0)).to(device=device, dtype=torch.long)
+ w_idx = torch.Tensor(np.stack(w_idx,0)).to(device=device, dtype=torch.long)
+ w_w = torch.Tensor(np.stack(w_w,0)).to(device=device, dtype=self.dtype)
+
+ return idx_nn,w_idx,w_w
+ '''
+
+ def make_idx_weights_from_one_cell_ids(self,
+ cell_ids,
+ polar=False,
+ gamma=1.0,
+ device='cuda',
+ allow_extrapolation=True):
+
+ idx_nn = self.knn_healpix_ckdtree(cell_ids,
+ self.KERNELSZ*self.KERNELSZ,
+ self.nside,
+ nest=self.nest,
+ )
+
+ mat_pt=self.rotation_matrices_from_healpix(self.nside,cell_ids,nest=self.nest)
+
+ t,p = hp.pix2ang(self.nside,cell_ids[idx_nn],nest=self.nest)
+
+ vec_orig=hp.ang2vec(t,p)
+
+ vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)

- rotate=2*((self.t<np.pi/2)-0.5)[:,None]
+ del vec_orig
+ del mat_pt
+
+ rotate=2*((t<np.pi/2)-0.5)[:,None]
  if polar:
- xx=np.cos(self.p)[:,None]*self.vec_rot[:,:,0]-rotate*np.sin(self.p)[:,None]*self.vec_rot[:,:,1]
- yy=-np.sin(self.p)[:,None]*self.vec_rot[:,:,0]-rotate*np.cos(self.p)[:,None]*self.vec_rot[:,:,1]
+ xx=np.cos(p)[:,None]*vec_rot[:,:,0]-rotate*np.sin(p)[:,None]*vec_rot[:,:,1]
+ yy=-np.sin(p)[:,None]*vec_rot[:,:,0]-rotate*np.cos(p)[:,None]*vec_rot[:,:,1]
  else:
- xx=self.vec_rot[:,:,0]
- yy=self.vec_rot[:,:,1]
+ xx=vec_rot[:,:,0]
+ yy=vec_rot[:,:,1]
+
+ del vec_rot
+ del rotate
+ del t
+ del p

- self.w_idx,self.w_w = self.bilinear_weights_NxN(xx*self.nside*gamma,
- yy*self.nside*gamma,
- allow_extrapolation=allow_extrapolation)
+ w_idx,w_w = self.bilinear_weights_NxN(xx*self.nside*gamma,
+ yy*self.nside*gamma,
+ allow_extrapolation=allow_extrapolation)
+ '''
+ # calib : [Npix, K]
+ calib = np.zeros((w_idx.shape[0], w_idx.shape[2]))
+ # Assumptions:
+ # w_idx.shape == (Npix, M, K) and w_w.shape == (Npix, M, K)
+ Npix, M, K = w_idx.shape
+ nb_cols = K
+
+ # 1) Accumulate via "bincount" with a per-row offset
+ row_ids = np.arange(Npix, dtype=np.int64)[:, None, None] * nb_cols
+ flat_idx = (row_ids + w_idx).ravel() # indices in [0, Npix*9)
+ weights = w_w.ravel().astype(np.float64) # or whichever dtype you prefer
+
+ calib = np.bincount(flat_idx, weights, minlength=Npix*nb_cols)\
+ .reshape(Npix, nb_cols)
+
+ # 2) Re-inject into norm_a according to w_idx
+ norm_a = calib[np.arange(Npix)[:, None, None], w_idx]
+
+ w_w /= norm_a
+ w_w = np.clip(w_w,0.0,1.0)

+ w_w[np.isnan(w_w)]=0.0
+ '''
+ #del xx
+ #del yy
+
+ return idx_nn,w_idx,w_w,xx,yy
+
+ def make_idx_weights(self,polar=False,gamma=1.0,device='cuda',allow_extrapolation=True,return_index=False):
+
+ idx_nn,w_idx,w_w = self.make_idx_weights_from_one_cell_ids(self.cell_ids,
+ polar=polar,
+ gamma=gamma,
+ device=device,
+ allow_extrapolation=allow_extrapolation)
+
  # Ensure types/devices
- self.idx_nn = torch.Tensor(self.idx_nn).to(device=device, dtype=torch.long)
- self.w_idx = torch.Tensor(self.w_idx).to(device=device, dtype=torch.long)
- self.w_w = torch.Tensor(self.w_w).to(device=device, dtype=torch.float64)
+ self.idx_nn = torch.Tensor(idx_nn).to(device=device, dtype=torch.long)
+ self.w_idx = torch.Tensor(w_idx).to(device=device, dtype=torch.long)
+ self.w_w = torch.Tensor(w_w).to(device=device, dtype=self.dtype)

  def _grid_index(self, xi, yi):
  """
@@ -410,108 +613,169 @@ class HOrientedConvol:
  idx = np.stack([i00, i10, i01, i11], axis=1).astype(np.int64)

  return idx, w
+
+ # --- Add inside class HOrientedConvol, just above Convol_torch ---
+ def _convol_single(self, im1: torch.Tensor, ww: torch.Tensor, cell_ids=None, nside=None):
+ """
+ Single-sample path. im1: (1, C_i, Npix_1). Returns (1, C_o, Npix_1).
+ """
+ if not isinstance(im1, torch.Tensor):
+ im1 = torch.as_tensor(im1, device=self.device, dtype=self.dtype)
+ if not isinstance(ww, torch.Tensor):
+ ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)
+ assert im1.ndim == 3 and im1.shape[0] == 1, f"expected (1, C_i, Npix), got {tuple(im1.shape)}"

- def Convol_torch(self, im, ww):
+ # Reuse the existing Convol_torch core by faking B=1 shapes.
+ # We call the existing (batched) implementation with B=1.
+ return self.Convol_torch(im1, ww, cell_ids=cell_ids, nside=nside) # returns (1, C_o, Npix_1)
+
+ # --- Replace the first lines of Convol_torch with a dispatcher ---
+ def Convol_torch(self, im, ww, cell_ids=None, nside=None):
  """
- Batched KERNELSZxKERNELSZ neighborhood aggregation in pure PyTorch (generalization of the 3x3 case).
-
- Parameters
- ----------
- im : Tensor, shape (B, C_i, Npix)
- Input features per pixel for a batch of B samples.
- ww : Tensor
- Base mixing weights, indexed along its 'M' dimension by self.w_idx.
- Supported shapes:
- (C_i, C_o, M)
- (C_i, C_o, M, S)
- (B, C_i, C_o, M)
- (B, C_i, C_o, M, S)
-
- Class members (already tensors; will be aligned to im.device/dtype):
- -------------------------------------------------------------------
- self.idx_nn : LongTensor, shape (Npix, P)
- For each center pixel, the P neighbor indices into the Npix axis of `im`.
- (P = K*K for a KxK neighborhood.)
- self.w_idx : LongTensor, shape (Npix, P) or (Npix, S, P)
- Indices along the 'M' dimension of ww, per (center[, sector], neighbor).
- self.w_w : Tensor, shape (Npix, P) or (Npix, S, P)
- Additional scalar weights per neighbor (same layout as w_idx).
-
- Returns
- -------
- out : Tensor, shape (B, C_o, Npix)
- Aggregated output per center pixel for each batch sample.
+ Batched KERNELSZxKERNELSZ aggregation.
+
+ Accepts either:
+ - im: Tensor (B, C_i, Npix) with one shared or per-batch (B,Npix) cell_ids
+ - im: list/tuple of Tensors, each (C_i, Npix_b), with cell_ids a list of arrays
  """
- # ---- Basic checks ----
+ import torch
+
+ # (A) Variable-length per-sample path: im is a list/tuple OR cell_ids is a list/tuple
+ if isinstance(im, (list, tuple)) or isinstance(cell_ids, (list, tuple)):
+ # Normalize to lists
+ im_list = im if isinstance(im, (list, tuple)) else [im]
+ cid_list = cell_ids if isinstance(cell_ids, (list, tuple)) else [cell_ids] * len(im_list)
+ assert len(im_list) == len(cid_list), "im list and cell_ids list must have same length"
+
+ outs = []
+ for xb, cb in zip(im_list, cid_list):
+ # xb: (C_i, Npix_b) -> (1, C_i, Npix_b)
+ if not torch.is_tensor(xb):
+ xb = torch.as_tensor(xb, device=self.device, dtype=self.dtype)
+ if xb.dim() == 2:
+ xb = xb.unsqueeze(0)
+ elif xb.dim() != 3 or xb.shape[0] != 1:
+ raise ValueError(f"Each sample must be (C,N) or (1,C,N); got {tuple(xb.shape)}")
+
+ yb = self._convol_single(xb, ww, cell_ids=cb, nside=nside) # (1, C_o, Npix_b)
+ outs.append(yb.squeeze(0)) # -> (C_o, Npix_b)
+ return outs # List[Tensor], each (C_o, Npix_b)
+
+ # (B) Standard fixed-length batched path (your current implementation)
+ # ... keep your existing Convol_torch body from here unchanged ...
+ # (paste your current function body starting from the type casting and assertions)
+
+ # ---- Basic checks / casting ----
+ if not isinstance(im, torch.Tensor):
+ im = torch.as_tensor(im, device=self.device, dtype=self.dtype)
+ if not isinstance(ww, torch.Tensor):
+ ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)
+
  assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"
- assert ww.shape[2]==self.KERNELSZ*self.KERNELSZ, f"`ww` must be (C_i, C_o, KERNELSZ*KERNELSZ), got {tuple(ww.shape)}"
-
  B, C_i, Npix = im.shape
  device = im.device
  dtype = im.dtype
-
- # Align class tensors to device/dtype
- idx_nn = self.idx_nn.to(device=device, dtype=torch.long) # (Npix, P)
- w_idx = self.w_idx.to(device=device, dtype=torch.long) # (Npix, P) or (Npix, S, P)
- w_w = self.w_w.to(device=device, dtype=dtype) # (Npix, P) or (Npix, S, P)
-
- # Neighbor count P inferred from idx_nn
- assert idx_nn.ndim == 2 and idx_nn.size(0) == Npix, \
- f"`idx_nn` must be (Npix, P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
- P = idx_nn.size(1)
-
- # ---- 1) Gather neighbor values from im along the Npix dimension -> (B, C_i, Npix, P)
- # im: (B,C_i,Npix) -> (B,C_i,Npix,1); idx: (1,1,Npix,P) broadcast over (B,C_i)
+
+ # ---- Recompute (idx_nn, w_idx, w_w) depending on cell_ids shape ----
+ # target shapes:
+ # idx_nn_eff : (B, Npix, P)
+ # w_idx_eff : (B, Npix, S, P)
+ # w_w_eff : (B, Npix, S, P)
+ if cell_ids is not None:
+ # ---- Recompute (idx_nn, w_idx, w_w) depending on cell_ids shape ----
+ # Normalize: accept a Tensor, an ndarray, or a single-element list/tuple (var-length case, B=1)
+ if isinstance(cell_ids, (list, tuple)):
+ # list of ids (often of length 1 in var-length mode)
+ if len(cell_ids) == 1:
+ cid = np.asarray(cell_ids[0])[None, :] # -> (1, Npix)
+ else:
+ # if there is ever more than one, try to stack them (each element must have the same Npix)
+ cid = np.stack([np.asarray(c) for c in cell_ids], axis=0)
+ elif torch.is_tensor(cell_ids):
+ c = cell_ids.detach().cpu().numpy()
+ cid = c if c.ndim != 1 else c[None, :] # promote to 2D when B=1
+ else:
+ c = np.asarray(cell_ids)
+ cid = c if c.ndim != 1 else c[None, :]
+
+ # cid is now (B, Npix)
+ idx_nn_eff, w_idx_eff, w_w_eff = self.make_idx_weights_from_cell_ids(
+ cid, nside, device=device
+ )
+ # shapes: (B, Npix, P), (B, Npix, S, P|P), (B, Npix, S, P|P)
+ P = idx_nn_eff.shape[-1]
+ S = w_idx_eff.shape[-2] if w_idx_eff.ndim == 4 else 1
+
+ # make sure dtypes/devices are consistent
+ idx_nn_eff = torch.as_tensor(idx_nn_eff, device=device, dtype=torch.long)
+ w_idx_eff = torch.as_tensor(w_idx_eff, device=device, dtype=torch.long)
+ w_w_eff = torch.as_tensor(w_w_eff, device=device, dtype=dtype)
+ else:
+ # Use precomputed (shared for batch)
+ if self.w_idx is None:
+ if self.cell_ids.ndim==1:
+ l_cell=self.cell_ids[None,:]
+ else:
+ l_cell=self.cell_ids
+
+ idx_nn,w_idx,w_w = self.make_idx_weights_from_cell_ids(
+ l_cell,
+ polar=self.polar,
+ gamma=self.gamma,
+ device=self.device,
+ allow_extrapolation=self.allow_extrapolation)
+
+ self.idx_nn = idx_nn
+ self.w_idx = w_idx
+ self.w_w = w_w
+ else:
+ idx_nn = self.idx_nn # (Npix,P)
+ w_idx = self.w_idx # (Npix,P) or (Npix,S,P)
+ w_w = self.w_w # (Npix,P) or (Npix,S,P)
+
+ #assert idx_nn.ndim == 3 and idx_nn.size(1) == Npix, \
+ # f"`idx_nn` must be (B,Npix,P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
+
+ P = idx_nn.size(-1)
+
+ if w_idx.ndim == 3:
+ S = 1
+ w_idx_eff = w_idx[:, :, None, :] # (B,Npix,1,P)
+ w_w_eff = w_w[:, :, None, :] # (B,Npix,1,P)
+ elif w_idx.ndim == 4:
+ S = w_idx.size(2)
+ w_idx_eff = w_idx # (B,Npix,S,P)
+ w_w_eff = w_w # (B,Npix,S,P)
+ else:
+ raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")
+ idx_nn_eff = idx_nn # (B,Npix,P)
+
+ # ---- 1) Gather neighbor values from im along Npix -> (B, C_i, Npix, P)
  rim = torch.take_along_dim(
- im.unsqueeze(-1),
- idx_nn.unsqueeze(0).unsqueeze(0),
+ im.unsqueeze(-1), # (B, C_i, Npix, 1)
+ idx_nn_eff[:, None, :, :], # (B, 1, Npix, P)
  dim=2
- ) # (B, C_i, Npix, P)
-
- # ---- 2) Normalize w_idx / w_w to include a sector dim S ----
- # Target layout: (Npix, S, P)
- if w_idx.ndim == 2:
- # (Npix, P) -> add sector dim S=1
- assert w_idx.size(0) == Npix and w_idx.size(1) == P
- w_idx_eff = w_idx.unsqueeze(1) # (Npix, 1, P)
- w_w_eff = w_w.unsqueeze(1) # (Npix, 1, P)
- S = 1
- elif w_idx.ndim == 3:
- # (Npix, S, P)
- Npix_, S, P_ = w_idx.shape
- assert Npix_ == Npix and P_ == P, \
- f"`w_idx` must be (Npix,S,P) with Npix={Npix}, P={P}, got {tuple(w_idx.shape)}"
- assert w_w.shape == w_idx.shape, "`w_w` must match `w_idx` shape"
- w_idx_eff = w_idx
- w_w_eff = w_w
- else:
- raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")
-
- # ---- 3) Normalize ww to (B, C_i, C_o, M, S) for uniform gather ----
+ )
+
+ # ---- 2) Normalize ww to (B, C_i, C_o, M, S)
  if ww.ndim == 3:
- # (C_i, C_o, M) -> (B, C_i, C_o, M, S)
  C_i_w, C_o, M = ww.shape
  assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
  ww_eff = ww.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, -1, S)
-
  elif ww.ndim == 4:
- # Could be (C_i, C_o, M, S) or (B, C_i, C_o, M)
  if ww.shape[0] == C_i and ww.shape[1] != C_i:
- # (C_i, C_o, M, S) -> (B, C_i, C_o, M, S)
+ # (C_i, C_o, M, S)
  C_i_w, C_o, M, S_w = ww.shape
  assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
  assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
  ww_eff = ww.unsqueeze(0).expand(B, -1, -1, -1, -1)
  elif ww.shape[0] == B:
- # (B, C_i, C_o, M) -> (B, C_i, C_o, M, S)
+ # (B, C_i, C_o, M)
  _, C_i_w, C_o, M = ww.shape
  assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
  ww_eff = ww.unsqueeze(-1).expand(-1, -1, -1, -1, S)
  else:
- raise ValueError(
- f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)"
- )
-
+ raise ValueError(f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)")
  elif ww.ndim == 5:
  # (B, C_i, C_o, M, S)
  assert ww.shape[0] == B and ww.shape[1] == C_i, "ww batch/C_i mismatch"
@@ -520,27 +784,150 @@ class HOrientedConvol:
  ww_eff = ww
  else:
  raise ValueError(f"Unsupported ww shape {tuple(ww.shape)}")
-
- # ---- 4) Gather along M using w_idx_eff -> (B, C_i, C_o, Npix, S, P)
- idx_exp = w_idx_eff.unsqueeze(0).unsqueeze(0).unsqueeze(0) # (1,1,1,Npix,S,P)
+
+ # --- Sanitize shapes: ensure w_idx_eff / w_w_eff == (B, Npix, S, P)
+
+ # ---- 3) Gather along M using w_idx_eff -> (B, C_i, C_o, Npix, S, P)
+ idx_exp = w_idx_eff[:, None, None, :, :, :] # (B,1,1,Npix,S,P)
  rw = torch.take_along_dim(
- ww_eff.unsqueeze(-1), # (B, C_i, C_o, M, S, 1)
- idx_exp, # (1,1,1,Npix,S,P) -> broadcast
- dim=3 # gather along M
+ ww_eff.unsqueeze(-1), # (B,C_i,C_o,M,S,1)
+ idx_exp, # (B,1,1,Npix,S,P)
+ dim=3 # gather along M
  ) # -> (B, C_i, C_o, Npix, S, P)
+ # ---- 4) Apply extra neighbor weights ----
+ rw = rw * w_w_eff[:, None, None, :, :, :] # (B, C_i, C_o, Npix, S, P)
+ # ---- 5) Combine neighbor values and weights ----
+ rim_exp = rim[:, :, None, :, None, :] # (B, C_i, 1, Npix, 1, P)
+ out_ci = (rim_exp * rw).sum(dim=-1) # sum over P -> (B, C_i, C_o, Npix, S)
+ out_ci = out_ci.sum(dim=-1) # sum over S -> (B, C_i, C_o, Npix)
+ out = out_ci.sum(dim=1) # sum over C_i -> (B, C_o, Npix)
+
+ return out

- # ---- 5) Apply extra neighbor weights ----
- rw = rw * w_w_eff.unsqueeze(0).unsqueeze(0).unsqueeze(0) # (B, C_i, C_o, Npix, S, P)
-
- # ---- 6) Combine neighbor values and weights ----
- # rim: (B, C_i, Npix, P) -> expand to (B, C_i, 1, Npix, 1, P)
- rim_exp = rim[:, :, None, :, None, :]
- # sum over neighbors (P), then over sectors (S), then over input channels (C_i)
- out_ci = (rim_exp * rw).sum(dim=-1) # (B, C_i, C_o, Npix, S)
- out_ci = out_ci.sum(dim=-1) # (B, C_i, C_o, Npix)
- out = out_ci.sum(dim=1) # (B, C_o, Npix)
+ def _to_numpy_1d(self, ids):
+ """Return a 1D numpy array of int64 for a single set of cell ids."""
+ import numpy as np, torch
+ if isinstance(ids, np.ndarray):
+ return ids.reshape(-1).astype(np.int64, copy=False)
+ if torch.is_tensor(ids):
+ return ids.detach().cpu().to(torch.long).view(-1).numpy()
+ # python list/tuple of ints
+ return np.asarray(ids, dtype=np.int64).reshape(-1)
+
+ def _is_varlength_batch(self, ids):
+ """
+ True if ids is a list/tuple of per-sample id arrays (var-length batch).
+ False if ids is a single array/tensor of ids (shared for whole batch).
+ """
+ import numpy as np, torch
+ if isinstance(ids, (list, tuple)):
+ return True
+ if isinstance(ids, np.ndarray) and ids.ndim == 2:
+ # This would be a dense (B, Npix) matrix -> NOT var-length list
+ return False
+ if torch.is_tensor(ids) and ids.dim() == 2:
+ return False
+ return False
+
+ def Down(self, im, cell_ids=None, nside=None,max_poll=False):
+ """
+ If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
+ If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
+ """
+ if self.f is None:
+ if self.dtype==torch.float64:
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
+ else:
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
+
+ if cell_ids is None:
+ dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside,max_poll=False)
+ return dim,cdim
+
+ if nside is None:
+ nside = self.nside
+
+ # var-length mode: list/tuple of ids, one per sample
+ if self._is_varlength_batch(cell_ids):
+ outs, outs_ids = [], []
+ B = len(cell_ids)
+ for b in range(B):
+ cid_b = self._to_numpy_1d(cell_ids[b])
+ # extract the matching sample from `im`
+ if torch.is_tensor(im):
+ xb = im[b:b+1] # (1, C, N_b)
+ yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
+ outs.append(yb.squeeze(0)) # (C, N_b')
+ else:
+ # if im is already a list of (C, N_b)
+ xb = im[b]
+ yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
+ outs.append(yb.squeeze(0))
+ outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
+ return outs, outs_ids
+
+ # common grid (a single vector of ids)
+ cid = self._to_numpy_1d(cell_ids)
+ return self.f.ud_grade_2(im, cell_ids=cid, nside=nside,max_poll=False)
+
+ def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
+ """
+ If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
+ If they are lists (var-length per sample) -> return list[Tensor].
+ """
+ if self.f is None:
+ if self.dtype==torch.float64:
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
+ else:
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
+
+ if cell_ids is None:
+ dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
+ return dim
+
+ if nside is None:
+ nside = self.nside
+
+ # var-length: parallel lists
+ if self._is_varlength_batch(cell_ids):
+ assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
+ "In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
+ outs = []
+ B = len(cell_ids)
+ for b in range(B):
+ cid_b = self._to_numpy_1d(cell_ids[b]) # coarse ids
+ ocid_b = self._to_numpy_1d(o_cell_ids[b]) # fine ids
+ if torch.is_tensor(im):
+ xb = im[b:b+1] # (1, C, N_b_coarse)
+ yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
+ o_cell_ids=ocid_b, force_init_index=True)
+ outs.append(yb.squeeze(0)) # (C, N_b_fine)
+ else:
+ xb = im[b] # (C, N_b_coarse)
+ yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
+ o_cell_ids=ocid_b, force_init_index=True)
+ outs.append(yb.squeeze(0))
+ return outs
+
+ # common grid
+ cid = self._to_numpy_1d(cell_ids)
+ ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
+ return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
+ o_cell_ids=ocid, force_init_index=True)
+
+ def to_tensor(self,x):
+ if self.f is None:
+ if self.dtype==torch.float64:
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
+ else:
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
+ return self.f.backend.bk_cast(x)

- return out
+ def to_numpy(self,x):
+ if isinstance(x,np.ndarray):
+ return x
+ return x.cpu().numpy()
+


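Note on the new Convol_torch aggregation: the core of the reworked method is a two-stage torch.take_along_dim gather. Neighbour pixel values are pulled along the Npix axis with idx_nn, per-neighbour kernel samples are pulled along the M axis with w_idx, scaled by the interpolation weights w_w, and the result is summed over neighbours, sectors and input channels. The standalone PyTorch sketch below reproduces only that shape bookkeeping on random toy tensors; all sizes and variable names are illustrative (they are not part of the foscat API), and it is a reading aid for the diff rather than a verified usage example of the released package.

import torch

# Toy sizes: B batches, C_i/C_o channels, Npix pixels, P = K*K neighbours,
# S sectors, M kernel samples addressed by w_idx (here M = P = 9 for a 3x3 kernel).
B, C_i, C_o, Npix, P, S, M = 2, 1, 4, 48, 9, 1, 9

im     = torch.randn(B, C_i, Npix)             # input maps
ww     = torch.randn(B, C_i, C_o, M, S)        # mixing weights, already expanded to 5D
idx_nn = torch.randint(0, Npix, (B, Npix, P))  # neighbour pixel indices per centre pixel
w_idx  = torch.randint(0, M, (B, Npix, S, P))  # which kernel sample each neighbour uses
w_w    = torch.rand(B, Npix, S, P)             # interpolation weights per neighbour

# 1) gather neighbour values along the pixel axis -> (B, C_i, Npix, P)
rim = torch.take_along_dim(im.unsqueeze(-1), idx_nn[:, None, :, :], dim=2)

# 2) gather kernel weights along M -> (B, C_i, C_o, Npix, S, P), then apply w_w
rw = torch.take_along_dim(ww.unsqueeze(-1), w_idx[:, None, None, :, :, :], dim=3)
rw = rw * w_w[:, None, None, :, :, :]

# 3) weighted sum over neighbours (P), sectors (S) and input channels (C_i) -> (B, C_o, Npix)
out = (rim[:, :, None, :, None, :] * rw).sum(dim=(-1, -2, 1))
print(out.shape)  # torch.Size([2, 4, 48])

Expressing the aggregation with gathers and broadcasting rather than sparse matrices keeps everything dense and batched, which appears to be what allows the new dispatcher to handle both a shared (B, Npix) grid and per-sample cell_ids lists.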