foscat 2025.8.3-py3-none-any.whl → 2025.9.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/HOrientedConvol.py CHANGED
@@ -3,16 +3,36 @@ import matplotlib.pyplot as plt
 import healpy as hp
 from scipy.sparse import csr_array
 import torch
+import foscat.scat_cov as sc
 from scipy.spatial import cKDTree

 class HOrientedConvol:
-    def __init__(self,nside,KERNELSZ,cell_ids=None,nest=True):
+    def __init__(self,
+                 nside,
+                 KERNELSZ,
+                 cell_ids=None,
+                 nest=True,
+                 device='cuda',
+                 dtype='float64',
+                 polar=False,
+                 gamma=1.0,
+                 allow_extrapolation=True,
+                 no_cell_ids=False,
+                 ):
+
+        if dtype=='float64':
+            self.dtype=torch.float64
+        else:
+            self.dtype=torch.float32

         if KERNELSZ % 2 == 0:
             raise ValueError(f"N must be odd so that coordinates are integers from -K..K; got N={KERNELSZ}.")

         self.local_test=False
-
+
+        if no_cell_ids==True:
+            cell_ids=np.arange(10)
+
         if cell_ids is None:
             self.cell_ids=np.arange(12*nside**2)

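Usage sketch (illustration only, not part of the released code): the constructor now takes device, precision and geometry options directly, and maps any dtype other than 'float64' to torch.float32. The import path below is assumed from the file name and the numeric values are placeholders:

    import numpy as np
    from foscat.HOrientedConvol import HOrientedConvol  # assumed module path

    nside, kernel = 16, 3                    # KERNELSZ must be odd, as enforced above
    cell_ids = np.arange(12 * nside**2)      # full-sky NESTED pixel indices
    conv = HOrientedConvol(nside,
                           kernel,
                           cell_ids=cell_ids,
                           nest=True,
                           device='cpu',     # the default in the diff is 'cuda'
                           dtype='float32')  # anything but 'float64' becomes torch.float32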
@@ -28,37 +48,84 @@ class HOrientedConvol:
             self.cell_ids=cell_ids

             self.local_test=True
-
-        idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
-                                          KERNELSZ*KERNELSZ,
-                                          nside,
-                                          nest=nest,
-                                          )

+        if self.cell_ids.ndim==1:
+            idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
+                                              KERNELSZ*KERNELSZ,
+                                              nside,
+                                              nest=nest,
+                                              )
+        else:
+            idx_nn = []
+            for k in range(self.cell_ids.shape[0]):
+                idx_nn.append(self.knn_healpix_ckdtree(self.cell_ids[k],
+                                                       KERNELSZ*KERNELSZ,
+                                                       nside,
+                                                       nest=nest,
+                                                       ))
+            idx_nn=np.stack(idx_nn,0)
+
+        if self.cell_ids.ndim==1:
+            mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids,nest=nest)

-        mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids,nest=nest)
+            if self.local_test:
+                t,p = hp.pix2ang(nside,self.cell_ids[idx_nn],nest=True)
+            else:
+                t,p = hp.pix2ang(nside,idx_nn,nest=True)
+
+            self.t=t[:,0]
+            self.p=p[:,0]
+            vec_orig=hp.ang2vec(t,p)
+
+            self.vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)

-        if self.local_test:
-            t,p = hp.pix2ang(nside,self.cell_ids[idx_nn],nest=True)
+            '''
+            if self.local_test:
+                idx_nn=self.remap_by_first_column(idx_nn)
+            '''
+
+            del mat_pt
+            del vec_orig
         else:
-            t,p = hp.pix2ang(nside,idx_nn,nest=True)

-        vec_orig=hp.ang2vec(t,p)
+            t,p,vec_rot = [],[],[]
+
+            for k in range(self.cell_ids.shape[0]):
+                mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids[k],nest=nest)
+
+                lt,lp = hp.pix2ang(nside,self.cell_ids[k,idx_nn[k]],nest=True)
+
+                vec_orig=hp.ang2vec(lt,lp)
+
+                l_vec_rot=np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
+                vec_rot.append(l_vec_rot)
+
+                del vec_orig
+                del mat_pt
+
+                t.append(lt[:,0])
+                p.append(lp[:,0])

-        self.vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
+
+            self.t=np.stack(t,0)
+            self.p=np.stack(p,0)
+            self.vec_rot=np.stack(vec_rot,0)

-        '''
-        if self.local_test:
-            idx_nn=self.remap_by_first_column(idx_nn)
-        '''
+            del t
+            del p
+            del vec_rot
+
+        self.polar=polar
+        self.gamma=gamma
+        self.device=device
+        self.allow_extrapolation=allow_extrapolation
+        self.w_idx=None

-        del mat_pt
-        del vec_orig
-        self.t=t[:,0]
-        self.p=p[:,0]
         self.idx_nn=idx_nn
         self.nside=nside
         self.KERNELSZ=KERNELSZ
+        self.nest=nest
+        self.f=None

     def remap_by_first_column(self,idx: np.ndarray) -> np.ndarray:
         """
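Usage sketch (illustration only): with the branch above, cell_ids may also be a 2-D array of shape (B, Npix_patch), in which case the k-NN indices, the per-pixel angles and the rotated neighbour vectors are built per batch entry and stacked along a leading axis. The pixel counts below are placeholders, and HOrientedConvol is imported as in the previous sketch:

    import numpy as np

    nside = 16
    patch = np.arange(256)                          # one local patch of NESTED pixel ids
    batch = np.stack([patch, patch + 256], axis=0)  # shape (2, 256): two patches

    single = HOrientedConvol(nside, 3, cell_ids=patch)  # idx_nn, t, p, vec_rot carry no batch axis
    multi = HOrientedConvol(nside, 3, cell_ids=batch)   # the same arrays stacked to shape (2, ...)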
@@ -290,25 +357,95 @@ class HOrientedConvol:

         return csr_array((w, (indice_1_0, indice_1_1)), shape=(12*self.nside**2, 12*self.nside**2*NORIENT))

-
-    def make_idx_weights(self,polar=False,gamma=1.0,device='cuda',allow_extrapolation=True):
+    def make_idx_weights_from_cell_ids(self,i_cell_ids,
+                                       polar=False,
+                                       gamma=1.0,
+                                       device='cuda',
+                                       allow_extrapolation=True):
+        if len(i_cell_ids.shape)<2:
+            cell_ids=i_cell_ids
+            n_cids=1
+        else:
+            cell_ids=i_cell_ids[0]
+            n_cids=i_cell_ids.shape[0]
+
+        idx_nn,w_idx,w_w = [],[],[]
+
+        for k in range(n_cids):
+            cell_ids=i_cell_ids[k]
+            l_idx_nn,l_w_idx,l_w_w = self.make_idx_weights_from_one_cell_ids(cell_ids,
+                                                                             polar=polar,
+                                                                             gamma=gamma,
+                                                                             device=device,
+                                                                             allow_extrapolation=allow_extrapolation)
+            idx_nn.append(l_idx_nn)
+            w_idx.append(l_w_idx)
+            w_w.append(l_w_w)
+
+        idx_nn = torch.Tensor(np.stack(idx_nn,0)).to(device=device, dtype=torch.long)
+        w_idx = torch.Tensor(np.stack(w_idx,0)).to(device=device, dtype=torch.long)
+        w_w = torch.Tensor(np.stack(w_w,0)).to(device=device, dtype=self.dtype)
+
+        return idx_nn,w_idx,w_w
+
+    def make_idx_weights_from_one_cell_ids(self,
+                                           cell_ids,
+                                           polar=False,
+                                           gamma=1.0,
+                                           device='cuda',
+                                           allow_extrapolation=True):
+
+        idx_nn = self.knn_healpix_ckdtree(cell_ids,
+                                          self.KERNELSZ*self.KERNELSZ,
+                                          self.nside,
+                                          nest=self.nest,
+                                          )
+
+        mat_pt=self.rotation_matrices_from_healpix(self.nside,cell_ids,nest=self.nest)
+
+        t,p = hp.pix2ang(self.nside,cell_ids[idx_nn],nest=self.nest)
+
+        vec_orig=hp.ang2vec(t,p)
+
+        vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
+
+        del vec_orig
+        del mat_pt

-        rotate=2*((self.t<np.pi/2)-0.5)[:,None]
+        rotate=2*((t<np.pi/2)-0.5)[:,None]
         if polar:
-            xx=np.cos(self.p)[:,None]*self.vec_rot[:,:,0]-rotate*np.sin(self.p)[:,None]*self.vec_rot[:,:,1]
-            yy=-np.sin(self.p)[:,None]*self.vec_rot[:,:,0]-rotate*np.cos(self.p)[:,None]*self.vec_rot[:,:,1]
+            xx=np.cos(p)[:,None]*vec_rot[:,:,0]-rotate*np.sin(p)[:,None]*vec_rot[:,:,1]
+            yy=-np.sin(p)[:,None]*vec_rot[:,:,0]-rotate*np.cos(p)[:,None]*vec_rot[:,:,1]
         else:
-            xx=self.vec_rot[:,:,0]
-            yy=self.vec_rot[:,:,1]
+            xx=vec_rot[:,:,0]
+            yy=vec_rot[:,:,1]
+
+        del vec_rot
+        del rotate
+        del t
+        del p

-        self.w_idx,self.w_w = self.bilinear_weights_NxN(xx*self.nside*gamma,
-                                                        yy*self.nside*gamma,
-                                                        allow_extrapolation=allow_extrapolation)
+        w_idx,w_w = self.bilinear_weights_NxN(xx*self.nside*gamma,
+                                              yy*self.nside*gamma,
+                                              allow_extrapolation=allow_extrapolation)

+        del xx
+        del yy
+
+        return idx_nn,w_idx,w_w
+
+    def make_idx_weights(self,polar=False,gamma=1.0,device='cuda',allow_extrapolation=True,return_index=False):
+
+        idx_nn,w_idx,w_w = self.make_idx_weights_from_one_cell_ids(self.cell_ids,
+                                                                   polar=polar,
+                                                                   gamma=gamma,
+                                                                   device=device,
+                                                                   allow_extrapolation=allow_extrapolation)
+
         # Ensure types/devices
-        self.idx_nn = torch.Tensor(self.idx_nn).to(device=device, dtype=torch.long)
-        self.w_idx = torch.Tensor(self.w_idx).to(device=device, dtype=torch.long)
-        self.w_w = torch.Tensor(self.w_w).to(device=device, dtype=torch.float64)
+        self.idx_nn = torch.Tensor(idx_nn).to(device=device, dtype=torch.long)
+        self.w_idx = torch.Tensor(w_idx).to(device=device, dtype=torch.long)
+        self.w_w = torch.Tensor(w_w).to(device=device, dtype=self.dtype)

     def _grid_index(self, xi, yi):
         """
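Usage sketch (illustration only): the old make_idx_weights body is split in two helpers. make_idx_weights_from_one_cell_ids returns the (idx_nn, w_idx, w_w) triplet for a single pixel list, make_idx_weights_from_cell_ids loops over a batch of lists and stacks the results into torch tensors, and make_idx_weights keeps the original entry point and caches the triplet on the instance. conv is the instance built in the first sketch and the batch shape is a placeholder:

    import numpy as np

    # Fill the cached self.idx_nn / self.w_idx / self.w_w used when cell_ids is omitted later on.
    conv.make_idx_weights(polar=False, gamma=1.0, device='cpu')

    # Build index/weight tensors for a batch of pixel lists without touching the cache;
    # the returned tensors carry a leading batch axis.
    batch = np.stack([np.arange(256), np.arange(256) + 256], axis=0)
    idx_nn, w_idx, w_w = conv.make_idx_weights_from_cell_ids(batch,
                                                             polar=False,
                                                             gamma=1.0,
                                                             device='cpu')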
@@ -411,7 +548,184 @@ class HOrientedConvol:

         return idx, w

-    def Convol_torch(self, im, ww):
+    def Convol_torch(self, im, ww, cell_ids=None, nside=None):
+        """
+        Batched KERNELSZxKERNELSZ neighborhood aggregation in pure PyTorch (generalization of the 3x3 case).
+
+        Parameters
+        ----------
+        im : Tensor, shape (B, C_i, Npix)
+        ww : Tensor, shapes supported:
+            (C_i, C_o, M) | (C_i, C_o, M, S) | (B, C_i, C_o, M) | (B, C_i, C_o, M, S)
+        cell_ids : ndarray or Tensor
+            - None: use precomputed self.idx_nn / self.w_idx / self.w_w (shared for batch).
+            - (Npix,): recompute once (shared for batch).
+            - (B, Npix): recompute per-sample (different for each b).
+
+        Returns
+        -------
+        out : Tensor, shape (B, C_o, Npix)
+        """
+        import torch
+
+        # ---- Basic checks / casting ----
+        if not isinstance(im, torch.Tensor):
+            im = torch.as_tensor(im, device=self.device, dtype=self.dtype)
+        if not isinstance(ww, torch.Tensor):
+            ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)
+
+        assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"
+        B, C_i, Npix = im.shape
+        device = im.device
+        dtype = im.dtype
+
+        # ---- Recompute (idx_nn, w_idx, w_w) depending on cell_ids shape ----
+        # target shapes:
+        #   idx_nn_eff : (B, Npix, P)
+        #   w_idx_eff  : (B, Npix, S, P)
+        #   w_w_eff    : (B, Npix, S, P)
+        if cell_ids is not None:
+            # to numpy for your make_idx_weights_from_cell_ids helper if needed
+            if isinstance(cell_ids, torch.Tensor):
+                cid = cell_ids.detach().to("cpu").numpy()
+            else:
+                cid = cell_ids
+
+            if cid.ndim == 1:
+                # single set of ids for the whole batch
+                idx_nn, w_idx, w_w = self.make_idx_weights_from_cell_ids(cid, nside, device=device)
+                assert idx_nn.ndim == 2, "idx_nn expected (Npix,P)"
+                P = idx_nn.shape[1]
+                if w_idx.ndim == 2:
+                    # (Npix,P) -> (B,Npix,1,P)
+                    S = 1
+                    w_idx_eff = w_idx[None, :, None, :].expand(B, -1, -1, -1)
+                    w_w_eff = w_w[None, :, None, :].expand(B, -1, -1, -1)
+                elif w_idx.ndim == 3:
+                    # (Npix,S,P) -> (B,Npix,S,P)
+                    S = w_idx.shape[1]
+                    w_idx_eff = w_idx[None, ...].expand(B, -1, -1, -1)
+                    w_w_eff = w_w[None, ...].expand(B, -1, -1, -1)
+                else:
+                    raise ValueError(f"Unsupported w_idx shape {tuple(w_idx.shape)}")
+                idx_nn_eff = idx_nn[None, ...].expand(B, -1, -1)  # (B,Npix,P)
+
+            elif cid.ndim == 2:
+                # per-sample ids
+                assert cid.shape[0] == B and cid.shape[1] == Npix, \
+                    f"cell_ids must be (B,Npix) with B={B},Npix={Npix}, got {cid.shape}"
+                S_ref = None
+
+                idx_nn_eff, w_idx_eff, w_w_eff = self.make_idx_weights_from_cell_ids(cid,
+                                                                                     nside,
+                                                                                     device=device)
+                P = idx_nn_eff.shape[-1]
+                S = w_idx_eff.shape[-2]
+
+            else:
+                raise ValueError(f"Unsupported cell_ids shape {cid.shape}")
+
+            # ensure tensors on right device/dtype
+            idx_nn_eff = torch.as_tensor(idx_nn_eff, device=device, dtype=torch.long)
+            w_idx_eff = torch.as_tensor(w_idx_eff, device=device, dtype=torch.long)
+            w_w_eff = torch.as_tensor(w_w_eff, device=device, dtype=dtype)
+
+        else:
+            # Use precomputed (shared for batch)
+            if self.w_idx is None:
+
+                if self.cell_ids.ndim==1:
+                    l_cell=self.cell_ids[None,:]
+                else:
+                    l_cell=self.cell_ids
+
+                idx_nn,w_idx,w_w = self.make_idx_weights_from_cell_ids(
+                    l_cell,
+                    polar=self.polar,
+                    gamma=self.gamma,
+                    device=self.device,
+                    allow_extrapolation=self.allow_extrapolation)
+
+                self.idx_nn = idx_nn
+                self.w_idx = w_idx
+                self.w_w = w_w
+            else:
+                idx_nn = self.idx_nn  # (Npix,P)
+                w_idx = self.w_idx    # (Npix,P) or (Npix,S,P)
+                w_w = self.w_w        # (Npix,P) or (Npix,S,P)
+
+            #assert idx_nn.ndim == 3 and idx_nn.size(1) == Npix, \
+            #    f"`idx_nn` must be (B,Npix,P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
+
+            P = idx_nn.size(-1)
+
+            if w_idx.ndim == 3:
+                S = 1
+                w_idx_eff = w_idx[:, :, None, :]  # (B,Npix,1,P)
+                w_w_eff = w_w[:, :, None, :]      # (B,Npix,1,P)
+            elif w_idx.ndim == 4:
+                S = w_idx.size(2)
+                w_idx_eff = w_idx                 # (B,Npix,S,P)
+                w_w_eff = w_w                     # (B,Npix,S,P)
+            else:
+                raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")
+            idx_nn_eff = idx_nn                   # (B,Npix,P)
+
+        # ---- 1) Gather neighbor values from im along Npix -> (B, C_i, Npix, P)
+        rim = torch.take_along_dim(
+            im.unsqueeze(-1),            # (B, C_i, Npix, 1)
+            idx_nn_eff[:, None, :, :],   # (B, 1, Npix, P)
+            dim=2
+        )
+
+        # ---- 2) Normalize ww to (B, C_i, C_o, M, S)
+        if ww.ndim == 3:
+            C_i_w, C_o, M = ww.shape
+            assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
+            ww_eff = ww.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, -1, S)
+        elif ww.ndim == 4:
+            if ww.shape[0] == C_i and ww.shape[1] != C_i:
+                # (C_i, C_o, M, S)
+                C_i_w, C_o, M, S_w = ww.shape
+                assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
+                assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
+                ww_eff = ww.unsqueeze(0).expand(B, -1, -1, -1, -1)
+            elif ww.shape[0] == B:
+                # (B, C_i, C_o, M)
+                _, C_i_w, C_o, M = ww.shape
+                assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
+                ww_eff = ww.unsqueeze(-1).expand(-1, -1, -1, -1, S)
+            else:
+                raise ValueError(f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)")
+        elif ww.ndim == 5:
+            # (B, C_i, C_o, M, S)
+            assert ww.shape[0] == B and ww.shape[1] == C_i, "ww batch/C_i mismatch"
+            _, _, _, M, S_w = ww.shape
+            assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
+            ww_eff = ww
+        else:
+            raise ValueError(f"Unsupported ww shape {tuple(ww.shape)}")

+        # --- Sanitize shapes: ensure w_idx_eff / w_w_eff == (B, Npix, S, P)
+
+        # ---- 3) Gather along M using w_idx_eff -> (B, C_i, C_o, Npix, S, P)
+        idx_exp = w_idx_eff[:, None, None, :, :, :]  # (B,1,1,Npix,S,P)
+        rw = torch.take_along_dim(
+            ww_eff.unsqueeze(-1),  # (B,C_i,C_o,M,S,1)
+            idx_exp,               # (B,1,1,Npix,S,P)
+            dim=3                  # gather along M
+        )  # -> (B, C_i, C_o, Npix, S, P)
+        # ---- 4) Apply extra neighbor weights ----
+        rw = rw * w_w_eff[:, None, None, :, :, :]  # (B, C_i, C_o, Npix, S, P)
+        # ---- 5) Combine neighbor values and weights ----
+        rim_exp = rim[:, :, None, :, None, :]  # (B, C_i, 1, Npix, 1, P)
+        out_ci = (rim_exp * rw).sum(dim=-1)    # sum over P -> (B, C_i, C_o, Npix, S)
+        out_ci = out_ci.sum(dim=-1)            # sum over S -> (B, C_i, C_o, Npix)
+        out = out_ci.sum(dim=1)                # sum over C_i -> (B, C_o, Npix)
+
+        return out
+
+    def Convol_torch_old(self, im, ww,cell_ids=None,nside=None):
         """
         Batched KERNELSZxKERNELSZ neighborhood aggregation in pure PyTorch (generalization of the 3x3 case).

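Usage sketch (illustration only): the docstring above fixes the calling convention of the rewritten Convol_torch, i.e. maps of shape (B, C_i, Npix), kernels with M = KERNELSZ*KERNELSZ taps, output of shape (B, C_o, Npix). The sketch reuses conv from the first sketch and assumes that the cached or freshly built index tensors match the layout this method expects:

    import torch

    B, C_i, C_o = 2, 1, 4
    npix = conv.cell_ids.shape[-1]       # number of pixels handled by this operator
    M = conv.KERNELSZ * conv.KERNELSZ    # kernel taps per pixel

    im = torch.randn(B, C_i, npix, dtype=torch.float32)
    ww = torch.randn(C_i, C_o, M, dtype=torch.float32)

    # With cell_ids=None the method reuses (or, on first call, builds and caches) the index tensors.
    out = conv.Convol_torch(im, ww)      # expected shape (B, C_o, npix)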
@@ -426,7 +740,11 @@ class HOrientedConvol:
             (C_i, C_o, M, S)
             (B, C_i, C_o, M)
             (B, C_i, C_o, M, S)
-
+
+        cell_ids : ndarray
+            If cell_ids is not None recompute the index and do not use the precomputed ones.
+            Note : The computation is then much longer.
+

         Class members (already tensors; will be aligned to im.device/dtype):
         -------------------------------------------------------------------
         self.idx_nn : LongTensor, shape (Npix, P)
@@ -443,18 +761,27 @@ class HOrientedConvol:
             Aggregated output per center pixel for each batch sample.
         """
         # ---- Basic checks ----
+        if not isinstance(im,torch.Tensor):
+            im=torch.Tensor(im).to(device=self.device, dtype=self.dtype)
+        if not isinstance(ww,torch.Tensor):
+            ww=torch.Tensor(ww).to(device=self.device, dtype=self.dtype)
+
         assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"
+
         assert ww.shape[2]==self.KERNELSZ*self.KERNELSZ, f"`ww` must be (C_i, C_o, KERNELSZ*KERNELSZ), got {tuple(ww.shape)}"

         B, C_i, Npix = im.shape
         device = im.device
         dtype = im.dtype
+
+        if cell_ids is not None:

-        # Align class tensors to device/dtype
-        idx_nn = self.idx_nn.to(device=device, dtype=torch.long)  # (Npix, P)
-        w_idx = self.w_idx.to(device=device, dtype=torch.long)    # (Npix, P) or (Npix, S, P)
-        w_w = self.w_w.to(device=device, dtype=dtype)             # (Npix, P) or (Npix, S, P)
-
+            idx_nn,w_idx,w_w = self.make_idx_weights_from_cell_ids(cell_ids,nside,device=device)
+        else:
+            idx_nn = self.idx_nn  # (Npix, P)
+            w_idx = self.w_idx    # (Npix, P) or (Npix, S, P)
+            w_w = self.w_w        # (Npix, P) or (Npix, S, P)
+
         # Neighbor count P inferred from idx_nn
         assert idx_nn.ndim == 2 and idx_nn.size(0) == Npix, \
             f"`idx_nn` must be (Npix, P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
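Note and sketch (illustration only): the legacy Convol_torch_old gains the same cell_ids escape hatch. When cell_ids is given, the (idx_nn, w_idx, w_w) triplet is rebuilt on every call through make_idx_weights_from_cell_ids, which the added docstring warns is much slower; otherwise the cached members are reused. The names patch_ids and im_patch are hypothetical and must agree on the pixel count:

    # Fast path: reuse the indices cached on the instance (e.g. after make_idx_weights()).
    out = conv.Convol_torch_old(im, ww)

    # Slow path: rebuild the indices for a different pixel set on this call only.
    out_patch = conv.Convol_torch_old(im_patch, ww, cell_ids=patch_ids)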
@@ -542,5 +869,82 @@ class HOrientedConvol:

         return out

+    def Down(self, im, cell_ids=None,nside=None):
+        if self.f is None:
+            if self.dtype==torch.float64:
+                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
+            else:
+                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
+
+        if cell_ids is None:
+            dim,_ = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside)
+            return dim
+        else:
+            if nside is None:
+                nside=self.nside
+            if len(cell_ids.shape)==1:
+                return self.f.ud_grade_2(im,cell_ids=cell_ids,nside=nside)
+            else:
+                assert im.shape[0] == cell_ids.shape[0], \
+                    f"cell_ids and data should have the same batch size (first column), got data={im.shape},cell_ids={cell_ids.shape}"
+
+                result,result_cell_ids = [],[]
+
+                for k in range(im.shape[0]):
+                    r,c = self.f.ud_grade_2(im[k],cell_ids=cell_ids[k],nside=nside)
+                    result.append(r)
+                    result_cell_ids.append(c)
+
+                result = torch.stack(result, dim=0)                    # (B,...,Npix)
+                result_cell_ids = torch.stack(result_cell_ids, dim=0)  # (B,Npix)
+                return result,result_cell_ids
+
+    def Up(self, im, cell_ids=None,nside=None,o_cell_ids=None):
+        if self.f is None:
+            if self.dtype==torch.float64:
+                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
+            else:
+                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
+
+        if cell_ids is None:
+            dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
+            return dim
+        else:
+            if nside is None:
+                nside=self.nside
+            if nside is None:
+                nside=self.nside
+            if len(cell_ids.shape)==1:
+                return self.f.up_grade(im,nside*2,cell_ids=cell_ids,nside=nside,o_cell_ids=o_cell_ids)
+            else:
+                assert im.shape[0] == cell_ids.shape[0], \
+                    f"cell_ids and data should have the same batch size (first column), got data={im.shape},cell_ids={cell_ids.shape}"
+
+                assert im.shape[0] == o_cell_ids.shape[0], \
+                    f"cell_ids and data should have the same batch size (first column), got data={im.shape},o_cell_ids={o_cell_ids.shape}"
+
+                result = []
+
+                for k in range(im.shape[0]):
+                    r= self.f.up_grade(im[k],nside*2,cell_ids=cell_ids[k],nside=nside,o_cell_ids=o_cell_ids[k])
+                    result.append(r)
+
+                result = torch.stack(result, dim=0)  # (B,...,Npix)
+                return result
+
+    def to_tensor(self,x):
+        if self.f is None:
+            if self.dtype==torch.float64:
+                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
+            else:
+                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
+        return self.f.backend.bk_cast(x)
+
+    def to_numpy(self,x):
+        if isinstance(x,np.ndarray):
+            return x
+        return x.cpu().numpy()
+
+

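Usage sketch (illustration only): Down and Up wrap ud_grade_2 and up_grade from foscat.scat_cov; a funct backend matching the configured dtype is created lazily on first use, a 1-D cell_ids is handled in a single call, and a (B, Npix) cell_ids triggers the per-sample loop and stack shown above. conv is the full-sky instance from the first sketch and the random map is a placeholder:

    import numpy as np

    im_t = conv.to_tensor(np.random.randn(12 * conv.nside**2))  # cast through the scat_cov backend

    coarse = conv.Down(im_t)      # one degrade step via ud_grade_2 (uses conv.cell_ids, conv.nside)
    fine = conv.Up(im_t)          # up_grade to 2 * conv.nside
    arr = conv.to_numpy(fine)     # back to numpy, e.g. for healpy plotting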