foscat 2025.8.4__py3-none-any.whl → 2025.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +309 -50
- foscat/FoCUS.py +74 -267
- foscat/HOrientedConvol.py +517 -130
- foscat/HealBili.py +309 -0
- foscat/Plot.py +331 -0
- foscat/SphericalStencil.py +1346 -0
- foscat/UNET.py +470 -179
- foscat/healpix_unet_torch.py +1202 -0
- foscat/scat_cov.py +3 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/METADATA +1 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/RECORD +14 -10
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/WHEEL +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/top_level.txt +0 -0
foscat/HOrientedConvol.py
CHANGED
|
@@ -3,16 +3,37 @@ import matplotlib.pyplot as plt
|
|
|
3
3
|
import healpy as hp
|
|
4
4
|
from scipy.sparse import csr_array
|
|
5
5
|
import torch
|
|
6
|
+
import foscat.scat_cov as sc
|
|
6
7
|
from scipy.spatial import cKDTree
|
|
7
8
|
|
|
8
9
|
class HOrientedConvol:
|
|
9
|
-
def __init__(self,
|
|
10
|
+
def __init__(self,
|
|
11
|
+
nside,
|
|
12
|
+
KERNELSZ,
|
|
13
|
+
cell_ids=None,
|
|
14
|
+
nest=True,
|
|
15
|
+
device='cuda',
|
|
16
|
+
dtype='float64',
|
|
17
|
+
polar=False,
|
|
18
|
+
gamma=1.0,
|
|
19
|
+
allow_extrapolation=True,
|
|
20
|
+
no_cell_ids=False,
|
|
21
|
+
):
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
if dtype=='float64':
|
|
25
|
+
self.dtype=torch.float64
|
|
26
|
+
else:
|
|
27
|
+
self.dtype=torch.float32
|
|
10
28
|
|
|
11
29
|
if KERNELSZ % 2 == 0:
|
|
12
30
|
raise ValueError(f"N must be odd so that coordinates are integers from -K..K; got N={KERNELSZ}.")
|
|
13
31
|
|
|
14
32
|
self.local_test=False
|
|
15
|
-
|
|
33
|
+
|
|
34
|
+
if no_cell_ids==True:
|
|
35
|
+
cell_ids=np.arange(10)
|
|
36
|
+
|
|
16
37
|
if cell_ids is None:
|
|
17
38
|
self.cell_ids=np.arange(12*nside**2)
|
|
18
39
|
|
|
@@ -28,37 +49,84 @@ class HOrientedConvol:
|
|
|
28
49
|
self.cell_ids=cell_ids
|
|
29
50
|
|
|
30
51
|
self.local_test=True
|
|
31
|
-
|
|
32
|
-
idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
|
|
33
|
-
KERNELSZ*KERNELSZ,
|
|
34
|
-
nside,
|
|
35
|
-
nest=nest,
|
|
36
|
-
)
|
|
37
52
|
|
|
53
|
+
if self.cell_ids.ndim==1:
|
|
54
|
+
idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
|
|
55
|
+
KERNELSZ*KERNELSZ,
|
|
56
|
+
nside,
|
|
57
|
+
nest=nest,
|
|
58
|
+
)
|
|
59
|
+
else:
|
|
60
|
+
idx_nn = []
|
|
61
|
+
for k in range(self.cell_ids.shape[0]):
|
|
62
|
+
idx_nn.append(self.knn_healpix_ckdtree(self.cell_ids[k],
|
|
63
|
+
KERNELSZ*KERNELSZ,
|
|
64
|
+
nside,
|
|
65
|
+
nest=nest,
|
|
66
|
+
))
|
|
67
|
+
idx_nn=np.stack(idx_nn,0)
|
|
68
|
+
|
|
69
|
+
if self.cell_ids.ndim==1:
|
|
70
|
+
mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids,nest=nest)
|
|
38
71
|
|
|
39
|
-
|
|
72
|
+
if self.local_test:
|
|
73
|
+
t,p = hp.pix2ang(nside,self.cell_ids[idx_nn],nest=True)
|
|
74
|
+
else:
|
|
75
|
+
t,p = hp.pix2ang(nside,idx_nn,nest=True)
|
|
76
|
+
|
|
77
|
+
self.t=t[:,0]
|
|
78
|
+
self.p=p[:,0]
|
|
79
|
+
vec_orig=hp.ang2vec(t,p)
|
|
40
80
|
|
|
41
|
-
|
|
42
|
-
|
|
81
|
+
self.vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
|
|
82
|
+
|
|
83
|
+
'''
|
|
84
|
+
if self.local_test:
|
|
85
|
+
idx_nn=self.remap_by_first_column(idx_nn)
|
|
86
|
+
'''
|
|
87
|
+
|
|
88
|
+
del mat_pt
|
|
89
|
+
del vec_orig
|
|
43
90
|
else:
|
|
44
|
-
t,p = hp.pix2ang(nside,idx_nn,nest=True)
|
|
45
91
|
|
|
46
|
-
|
|
92
|
+
t,p,vec_rot = [],[],[]
|
|
93
|
+
|
|
94
|
+
for k in range(self.cell_ids.shape[0]):
|
|
95
|
+
mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids[k],nest=nest)
|
|
96
|
+
|
|
97
|
+
lt,lp = hp.pix2ang(nside,self.cell_ids[k,idx_nn[k]],nest=True)
|
|
98
|
+
|
|
99
|
+
vec_orig=hp.ang2vec(lt,lp)
|
|
100
|
+
|
|
101
|
+
l_vec_rot=np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
|
|
102
|
+
vec_rot.append(l_vec_rot)
|
|
103
|
+
|
|
104
|
+
del vec_orig
|
|
105
|
+
del mat_pt
|
|
106
|
+
|
|
107
|
+
t.append(lt[:,0])
|
|
108
|
+
p.append(lp[:,0])
|
|
47
109
|
|
|
48
|
-
|
|
110
|
+
|
|
111
|
+
self.t=np.stack(t,0)
|
|
112
|
+
self.p=np.stack(p,0)
|
|
113
|
+
self.vec_rot=np.stack(vec_rot,0)
|
|
49
114
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
115
|
+
del t
|
|
116
|
+
del p
|
|
117
|
+
del vec_rot
|
|
118
|
+
|
|
119
|
+
self.polar=polar
|
|
120
|
+
self.gamma=gamma
|
|
121
|
+
self.device=device
|
|
122
|
+
self.allow_extrapolation=allow_extrapolation
|
|
123
|
+
self.w_idx=None
|
|
54
124
|
|
|
55
|
-
del mat_pt
|
|
56
|
-
del vec_orig
|
|
57
|
-
self.t=t[:,0]
|
|
58
|
-
self.p=p[:,0]
|
|
59
125
|
self.idx_nn=idx_nn
|
|
60
126
|
self.nside=nside
|
|
61
127
|
self.KERNELSZ=KERNELSZ
|
|
128
|
+
self.nest=nest
|
|
129
|
+
self.f=None
|
|
62
130
|
|
|
63
131
|
def remap_by_first_column(self,idx: np.ndarray) -> np.ndarray:
|
|
64
132
|
"""
|
|
@@ -290,25 +358,160 @@ class HOrientedConvol:
|
|
|
290
358
|
|
|
291
359
|
return csr_array((w, (indice_1_0, indice_1_1)), shape=(12*self.nside**2, 12*self.nside**2*NORIENT))
|
|
292
360
|
|
|
293
|
-
|
|
294
|
-
|
|
361
|
+
def make_idx_weights_from_cell_ids(self,
|
|
362
|
+
i_cell_ids,
|
|
363
|
+
polar=False,
|
|
364
|
+
gamma=1.0,
|
|
365
|
+
device='cuda',
|
|
366
|
+
allow_extrapolation=True):
|
|
367
|
+
"""
|
|
368
|
+
Accept 1D (Npix,) or 2D (B, Npix) cell_ids and return
|
|
369
|
+
tensors batched sur la 1ère dim (B, ...).
|
|
370
|
+
"""
|
|
371
|
+
# → cast numpy
|
|
372
|
+
if torch.is_tensor(i_cell_ids):
|
|
373
|
+
cid = i_cell_ids.detach().cpu().numpy()
|
|
374
|
+
else:
|
|
375
|
+
cid = np.asarray(i_cell_ids)
|
|
376
|
+
|
|
377
|
+
# --- 1D: pas de boucle, on calcule une fois, puis on ajoute l'axe batch
|
|
378
|
+
if cid.ndim == 1:
|
|
379
|
+
l_idx_nn, l_w_idx, l_w_w = self.make_idx_weights_from_one_cell_ids(
|
|
380
|
+
cid, polar=polar, gamma=gamma, device=device,
|
|
381
|
+
allow_extrapolation=allow_extrapolation
|
|
382
|
+
)
|
|
383
|
+
idx_nn = torch.as_tensor(l_idx_nn, device=device, dtype=torch.long)[None, ...] # (1, Npix, P)
|
|
384
|
+
w_idx = torch.as_tensor(l_w_idx, device=device, dtype=torch.long)[None, ...] # (1, Npix, S, P) ou (1, Npix, P)
|
|
385
|
+
w_w = torch.as_tensor(l_w_w, device=device, dtype=self.dtype)[None, ...] # (1, Npix, S, P) ou (1, Npix, P)
|
|
386
|
+
return idx_nn, w_idx, w_w
|
|
387
|
+
|
|
388
|
+
# --- 2D: boucle sur b, empilement en (B, ...)
|
|
389
|
+
elif cid.ndim == 2:
|
|
390
|
+
outs = [ self.make_idx_weights_from_one_cell_ids(
|
|
391
|
+
cid[k], polar=polar, gamma=gamma, device=device,
|
|
392
|
+
allow_extrapolation=allow_extrapolation)
|
|
393
|
+
for k in range(cid.shape[0]) ]
|
|
394
|
+
idx_nn = torch.as_tensor(np.stack([o[0] for o in outs], axis=0), device=device, dtype=torch.long)
|
|
395
|
+
w_idx = torch.as_tensor(np.stack([o[1] for o in outs], axis=0), device=device, dtype=torch.long)
|
|
396
|
+
w_w = torch.as_tensor(np.stack([o[2] for o in outs], axis=0), device=device, dtype=self.dtype)
|
|
397
|
+
return idx_nn, w_idx, w_w
|
|
398
|
+
|
|
399
|
+
else:
|
|
400
|
+
raise ValueError(f"Unsupported cell_ids ndim={cid.ndim}; expected 1 or 2.")
|
|
401
|
+
'''
|
|
402
|
+
def make_idx_weights_from_cell_ids(self,i_cell_ids,
|
|
403
|
+
polar=False,
|
|
404
|
+
gamma=1.0,
|
|
405
|
+
device='cuda',
|
|
406
|
+
allow_extrapolation=True):
|
|
407
|
+
if len(i_cell_ids.shape)<2:
|
|
408
|
+
cell_ids=i_cell_ids
|
|
409
|
+
n_cids=1
|
|
410
|
+
else:
|
|
411
|
+
cell_ids=i_cell_ids[0]
|
|
412
|
+
n_cids=i_cell_ids.shape[0]
|
|
413
|
+
|
|
414
|
+
idx_nn,w_idx,w_w = [],[],[]
|
|
415
|
+
|
|
416
|
+
for k in range(n_cids):
|
|
417
|
+
cell_ids=i_cell_ids[k]
|
|
418
|
+
l_idx_nn,l_w_idx,l_w_w = self.make_idx_weights_from_one_cell_ids(cell_ids,
|
|
419
|
+
polar=polar,
|
|
420
|
+
gamma=gamma,
|
|
421
|
+
device=device,
|
|
422
|
+
allow_extrapolation=allow_extrapolation)
|
|
423
|
+
idx_nn.append(l_idx_nn)
|
|
424
|
+
w_idx.append(l_w_idx)
|
|
425
|
+
w_w.append(l_w_w)
|
|
426
|
+
|
|
427
|
+
idx_nn = torch.Tensor(np.stack(idx_nn,0)).to(device=device, dtype=torch.long)
|
|
428
|
+
w_idx = torch.Tensor(np.stack(w_idx,0)).to(device=device, dtype=torch.long)
|
|
429
|
+
w_w = torch.Tensor(np.stack(w_w,0)).to(device=device, dtype=self.dtype)
|
|
430
|
+
|
|
431
|
+
return idx_nn,w_idx,w_w
|
|
432
|
+
'''
|
|
433
|
+
|
|
434
|
+
def make_idx_weights_from_one_cell_ids(self,
|
|
435
|
+
cell_ids,
|
|
436
|
+
polar=False,
|
|
437
|
+
gamma=1.0,
|
|
438
|
+
device='cuda',
|
|
439
|
+
allow_extrapolation=True):
|
|
440
|
+
|
|
441
|
+
idx_nn = self.knn_healpix_ckdtree(cell_ids,
|
|
442
|
+
self.KERNELSZ*self.KERNELSZ,
|
|
443
|
+
self.nside,
|
|
444
|
+
nest=self.nest,
|
|
445
|
+
)
|
|
446
|
+
|
|
447
|
+
mat_pt=self.rotation_matrices_from_healpix(self.nside,cell_ids,nest=self.nest)
|
|
448
|
+
|
|
449
|
+
t,p = hp.pix2ang(self.nside,cell_ids[idx_nn],nest=self.nest)
|
|
450
|
+
|
|
451
|
+
vec_orig=hp.ang2vec(t,p)
|
|
452
|
+
|
|
453
|
+
vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
|
|
295
454
|
|
|
296
|
-
|
|
455
|
+
del vec_orig
|
|
456
|
+
del mat_pt
|
|
457
|
+
|
|
458
|
+
rotate=2*((t<np.pi/2)-0.5)[:,None]
|
|
297
459
|
if polar:
|
|
298
|
-
xx=np.cos(
|
|
299
|
-
yy=-np.sin(
|
|
460
|
+
xx=np.cos(p)[:,None]*vec_rot[:,:,0]-rotate*np.sin(p)[:,None]*vec_rot[:,:,1]
|
|
461
|
+
yy=-np.sin(p)[:,None]*vec_rot[:,:,0]-rotate*np.cos(p)[:,None]*vec_rot[:,:,1]
|
|
300
462
|
else:
|
|
301
|
-
xx=
|
|
302
|
-
yy=
|
|
463
|
+
xx=vec_rot[:,:,0]
|
|
464
|
+
yy=vec_rot[:,:,1]
|
|
465
|
+
|
|
466
|
+
del vec_rot
|
|
467
|
+
del rotate
|
|
468
|
+
del t
|
|
469
|
+
del p
|
|
303
470
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
471
|
+
w_idx,w_w = self.bilinear_weights_NxN(xx*self.nside*gamma,
|
|
472
|
+
yy*self.nside*gamma,
|
|
473
|
+
allow_extrapolation=allow_extrapolation)
|
|
474
|
+
'''
|
|
475
|
+
# calib : [Npix, K]
|
|
476
|
+
calib = np.zeros((w_idx.shape[0], w_idx.shape[2]))
|
|
477
|
+
# Hypothèses :
|
|
478
|
+
# w_idx.shape == (Npix, M, K) et w_w.shape == (Npix, M, K)
|
|
479
|
+
Npix, M, K = w_idx.shape
|
|
480
|
+
nb_cols = K
|
|
481
|
+
|
|
482
|
+
# 1) Accumulation par "bincount" avec décalage de ligne
|
|
483
|
+
row_ids = np.arange(Npix, dtype=np.int64)[:, None, None] * nb_cols
|
|
484
|
+
flat_idx = (row_ids + w_idx).ravel() # indices dans [0, Npix*9)
|
|
485
|
+
weights = w_w.ravel().astype(np.float64) # ou dtype de ton choix
|
|
486
|
+
|
|
487
|
+
calib = np.bincount(flat_idx, weights, minlength=Npix*nb_cols)\
|
|
488
|
+
.reshape(Npix, nb_cols)
|
|
489
|
+
|
|
490
|
+
# 2) Réinjection dans norm_a selon w_idx
|
|
491
|
+
norm_a = calib[np.arange(Npix)[:, None, None], w_idx]
|
|
492
|
+
|
|
493
|
+
w_w /= norm_a
|
|
494
|
+
w_w = np.clip(w_w,0.0,1.0)
|
|
307
495
|
|
|
496
|
+
w_w[np.isnan(w_w)]=0.0
|
|
497
|
+
'''
|
|
498
|
+
#del xx
|
|
499
|
+
#del yy
|
|
500
|
+
|
|
501
|
+
return idx_nn,w_idx,w_w,xx,yy
|
|
502
|
+
|
|
503
|
+
def make_idx_weights(self,polar=False,gamma=1.0,device='cuda',allow_extrapolation=True,return_index=False):
|
|
504
|
+
|
|
505
|
+
idx_nn,w_idx,w_w = self.make_idx_weights_from_one_cell_ids(self.cell_ids,
|
|
506
|
+
polar=polar,
|
|
507
|
+
gamma=gamma,
|
|
508
|
+
device=device,
|
|
509
|
+
allow_extrapolation=allow_extrapolation)
|
|
510
|
+
|
|
308
511
|
# Ensure types/devices
|
|
309
|
-
self.idx_nn = torch.Tensor(
|
|
310
|
-
self.w_idx = torch.Tensor(
|
|
311
|
-
self.w_w = torch.Tensor(
|
|
512
|
+
self.idx_nn = torch.Tensor(idx_nn).to(device=device, dtype=torch.long)
|
|
513
|
+
self.w_idx = torch.Tensor(w_idx).to(device=device, dtype=torch.long)
|
|
514
|
+
self.w_w = torch.Tensor(w_w).to(device=device, dtype=self.dtype)
|
|
312
515
|
|
|
313
516
|
def _grid_index(self, xi, yi):
|
|
314
517
|
"""
|
|
@@ -410,108 +613,169 @@ class HOrientedConvol:
|
|
|
410
613
|
idx = np.stack([i00, i10, i01, i11], axis=1).astype(np.int64)
|
|
411
614
|
|
|
412
615
|
return idx, w
|
|
616
|
+
|
|
617
|
+
# --- Add inside class HOrientedConvol, just above Convol_torch ---
|
|
618
|
+
def _convol_single(self, im1: torch.Tensor, ww: torch.Tensor, cell_ids=None, nside=None):
|
|
619
|
+
"""
|
|
620
|
+
Single-sample path. im1: (1, C_i, Npix_1). Returns (1, C_o, Npix_1).
|
|
621
|
+
"""
|
|
622
|
+
if not isinstance(im1, torch.Tensor):
|
|
623
|
+
im1 = torch.as_tensor(im1, device=self.device, dtype=self.dtype)
|
|
624
|
+
if not isinstance(ww, torch.Tensor):
|
|
625
|
+
ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)
|
|
626
|
+
assert im1.ndim == 3 and im1.shape[0] == 1, f"expected (1, C_i, Npix), got {tuple(im1.shape)}"
|
|
413
627
|
|
|
414
|
-
|
|
628
|
+
# Reuse the existing Convol_torch core by faking B=1 shapes.
|
|
629
|
+
# We call the existing (batched) implementation with B=1.
|
|
630
|
+
return self.Convol_torch(im1, ww, cell_ids=cell_ids, nside=nside) # returns (1, C_o, Npix_1)
|
|
631
|
+
|
|
632
|
+
# --- Replace the first lines of Convol_torch with a dispatcher ---
|
|
633
|
+
def Convol_torch(self, im, ww, cell_ids=None, nside=None):
|
|
415
634
|
"""
|
|
416
|
-
Batched KERNELSZxKERNELSZ
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
Input features per pixel for a batch of B samples.
|
|
422
|
-
ww : Tensor
|
|
423
|
-
Base mixing weights, indexed along its 'M' dimension by self.w_idx.
|
|
424
|
-
Supported shapes:
|
|
425
|
-
(C_i, C_o, M)
|
|
426
|
-
(C_i, C_o, M, S)
|
|
427
|
-
(B, C_i, C_o, M)
|
|
428
|
-
(B, C_i, C_o, M, S)
|
|
429
|
-
|
|
430
|
-
Class members (already tensors; will be aligned to im.device/dtype):
|
|
431
|
-
-------------------------------------------------------------------
|
|
432
|
-
self.idx_nn : LongTensor, shape (Npix, P)
|
|
433
|
-
For each center pixel, the P neighbor indices into the Npix axis of `im`.
|
|
434
|
-
(P = K*K for a KxK neighborhood.)
|
|
435
|
-
self.w_idx : LongTensor, shape (Npix, P) or (Npix, S, P)
|
|
436
|
-
Indices along the 'M' dimension of ww, per (center[, sector], neighbor).
|
|
437
|
-
self.w_w : Tensor, shape (Npix, P) or (Npix, S, P)
|
|
438
|
-
Additional scalar weights per neighbor (same layout as w_idx).
|
|
439
|
-
|
|
440
|
-
Returns
|
|
441
|
-
-------
|
|
442
|
-
out : Tensor, shape (B, C_o, Npix)
|
|
443
|
-
Aggregated output per center pixel for each batch sample.
|
|
635
|
+
Batched KERNELSZxKERNELSZ aggregation.
|
|
636
|
+
|
|
637
|
+
Accepts either:
|
|
638
|
+
- im: Tensor (B, C_i, Npix) with one shared or per-batch (B,Npix) cell_ids
|
|
639
|
+
- im: list/tuple of Tensors, each (C_i, Npix_b), with cell_ids a list of arrays
|
|
444
640
|
"""
|
|
445
|
-
|
|
641
|
+
import torch
|
|
642
|
+
|
|
643
|
+
# (A) Variable-length per-sample path: im is a list/tuple OR cell_ids is a list/tuple
|
|
644
|
+
if isinstance(im, (list, tuple)) or isinstance(cell_ids, (list, tuple)):
|
|
645
|
+
# Normalize to lists
|
|
646
|
+
im_list = im if isinstance(im, (list, tuple)) else [im]
|
|
647
|
+
cid_list = cell_ids if isinstance(cell_ids, (list, tuple)) else [cell_ids] * len(im_list)
|
|
648
|
+
assert len(im_list) == len(cid_list), "im list and cell_ids list must have same length"
|
|
649
|
+
|
|
650
|
+
outs = []
|
|
651
|
+
for xb, cb in zip(im_list, cid_list):
|
|
652
|
+
# xb: (C_i, Npix_b) -> (1, C_i, Npix_b)
|
|
653
|
+
if not torch.is_tensor(xb):
|
|
654
|
+
xb = torch.as_tensor(xb, device=self.device, dtype=self.dtype)
|
|
655
|
+
if xb.dim() == 2:
|
|
656
|
+
xb = xb.unsqueeze(0)
|
|
657
|
+
elif xb.dim() != 3 or xb.shape[0] != 1:
|
|
658
|
+
raise ValueError(f"Each sample must be (C,N) or (1,C,N); got {tuple(xb.shape)}")
|
|
659
|
+
|
|
660
|
+
yb = self._convol_single(xb, ww, cell_ids=cb, nside=nside) # (1, C_o, Npix_b)
|
|
661
|
+
outs.append(yb.squeeze(0)) # -> (C_o, Npix_b)
|
|
662
|
+
return outs # List[Tensor], each (C_o, Npix_b)
|
|
663
|
+
|
|
664
|
+
# (B) Standard fixed-length batched path (your current implementation)
|
|
665
|
+
# ... keep your existing Convol_torch body from here unchanged ...
|
|
666
|
+
# (paste your current function body starting from the type casting and assertions)
|
|
667
|
+
|
|
668
|
+
# ---- Basic checks / casting ----
|
|
669
|
+
if not isinstance(im, torch.Tensor):
|
|
670
|
+
im = torch.as_tensor(im, device=self.device, dtype=self.dtype)
|
|
671
|
+
if not isinstance(ww, torch.Tensor):
|
|
672
|
+
ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)
|
|
673
|
+
|
|
446
674
|
assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"
|
|
447
|
-
assert ww.shape[2]==self.KERNELSZ*self.KERNELSZ, f"`ww` must be (C_i, C_o, KERNELSZ*KERNELSZ), got {tuple(ww.shape)}"
|
|
448
|
-
|
|
449
675
|
B, C_i, Npix = im.shape
|
|
450
676
|
device = im.device
|
|
451
677
|
dtype = im.dtype
|
|
452
|
-
|
|
453
|
-
#
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
678
|
+
|
|
679
|
+
# ---- Recompute (idx_nn, w_idx, w_w) depending on cell_ids shape ----
|
|
680
|
+
# target shapes:
|
|
681
|
+
# idx_nn_eff : (B, Npix, P)
|
|
682
|
+
# w_idx_eff : (B, Npix, S, P)
|
|
683
|
+
# w_w_eff : (B, Npix, S, P)
|
|
684
|
+
if cell_ids is not None:
|
|
685
|
+
# ---- Recompute (idx_nn, w_idx, w_w) depending on cell_ids shape ----
|
|
686
|
+
# Normaliser: accepter Tensor, ndarray ou list/tuple de 1 élément (cas var-length, B=1)
|
|
687
|
+
if isinstance(cell_ids, (list, tuple)):
|
|
688
|
+
# liste d'ids (souvent longueur 1 en var-length)
|
|
689
|
+
if len(cell_ids) == 1:
|
|
690
|
+
cid = np.asarray(cell_ids[0])[None, :] # -> (1, Npix)
|
|
691
|
+
else:
|
|
692
|
+
# si jamais >1, on essaie d'empiler (doit avoir même Npix par élément)
|
|
693
|
+
cid = np.stack([np.asarray(c) for c in cell_ids], axis=0)
|
|
694
|
+
elif torch.is_tensor(cell_ids):
|
|
695
|
+
c = cell_ids.detach().cpu().numpy()
|
|
696
|
+
cid = c if c.ndim != 1 else c[None, :] # uniformiser en 2D quand B=1
|
|
697
|
+
else:
|
|
698
|
+
c = np.asarray(cell_ids)
|
|
699
|
+
cid = c if c.ndim != 1 else c[None, :]
|
|
700
|
+
|
|
701
|
+
# cid est maintenant (B, Npix)
|
|
702
|
+
idx_nn_eff, w_idx_eff, w_w_eff = self.make_idx_weights_from_cell_ids(
|
|
703
|
+
cid, nside, device=device
|
|
704
|
+
)
|
|
705
|
+
# shapes: (B, Npix, P), (B, Npix, S, P|P), (B, Npix, S, P|P)
|
|
706
|
+
P = idx_nn_eff.shape[-1]
|
|
707
|
+
S = w_idx_eff.shape[-2] if w_idx_eff.ndim == 4 else 1
|
|
708
|
+
|
|
709
|
+
# s’assurer des dtypes/devices
|
|
710
|
+
idx_nn_eff = torch.as_tensor(idx_nn_eff, device=device, dtype=torch.long)
|
|
711
|
+
w_idx_eff = torch.as_tensor(w_idx_eff, device=device, dtype=torch.long)
|
|
712
|
+
w_w_eff = torch.as_tensor(w_w_eff, device=device, dtype=dtype)
|
|
713
|
+
else:
|
|
714
|
+
# Use precomputed (shared for batch)
|
|
715
|
+
if self.w_idx is None:
|
|
716
|
+
if self.cell_ids.ndim==1:
|
|
717
|
+
l_cell=self.cell_ids[None,:]
|
|
718
|
+
else:
|
|
719
|
+
l_cell=self.cell_ids
|
|
720
|
+
|
|
721
|
+
idx_nn,w_idx,w_w = self.make_idx_weights_from_cell_ids(
|
|
722
|
+
l_cell,
|
|
723
|
+
polar=self.polar,
|
|
724
|
+
gamma=self.gamma,
|
|
725
|
+
device=self.device,
|
|
726
|
+
allow_extrapolation=self.allow_extrapolation)
|
|
727
|
+
|
|
728
|
+
self.idx_nn = idx_nn
|
|
729
|
+
self.w_idx = w_idx
|
|
730
|
+
self.w_w = w_w
|
|
731
|
+
else:
|
|
732
|
+
idx_nn = self.idx_nn # (Npix,P)
|
|
733
|
+
w_idx = self.w_idx # (Npix,P) or (Npix,S,P)
|
|
734
|
+
w_w = self.w_w # (Npix,P) or (Npix,S,P)
|
|
735
|
+
|
|
736
|
+
#assert idx_nn.ndim == 3 and idx_nn.size(1) == Npix, \
|
|
737
|
+
# f"`idx_nn` must be (B,Npix,P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
|
|
738
|
+
|
|
739
|
+
P = idx_nn.size(-1)
|
|
740
|
+
|
|
741
|
+
if w_idx.ndim == 3:
|
|
742
|
+
S = 1
|
|
743
|
+
w_idx_eff = w_idx[:, :, None, :] # (B,Npix,1,P)
|
|
744
|
+
w_w_eff = w_w[:, :, None, :] # (B,Npix,1,P)
|
|
745
|
+
elif w_idx.ndim == 4:
|
|
746
|
+
S = w_idx.size(2)
|
|
747
|
+
w_idx_eff = w_idx # (B,Npix,S,P)
|
|
748
|
+
w_w_eff = w_w # (B,Npix,S,P)
|
|
749
|
+
else:
|
|
750
|
+
raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")
|
|
751
|
+
idx_nn_eff = idx_nn # (B,Npix,P)
|
|
752
|
+
|
|
753
|
+
# ---- 1) Gather neighbor values from im along Npix -> (B, C_i, Npix, P)
|
|
465
754
|
rim = torch.take_along_dim(
|
|
466
|
-
im.unsqueeze(-1),
|
|
467
|
-
|
|
755
|
+
im.unsqueeze(-1), # (B, C_i, Npix, 1)
|
|
756
|
+
idx_nn_eff[:, None, :, :], # (B, 1, Npix, P)
|
|
468
757
|
dim=2
|
|
469
|
-
)
|
|
470
|
-
|
|
471
|
-
# ---- 2) Normalize
|
|
472
|
-
# Target layout: (Npix, S, P)
|
|
473
|
-
if w_idx.ndim == 2:
|
|
474
|
-
# (Npix, P) -> add sector dim S=1
|
|
475
|
-
assert w_idx.size(0) == Npix and w_idx.size(1) == P
|
|
476
|
-
w_idx_eff = w_idx.unsqueeze(1) # (Npix, 1, P)
|
|
477
|
-
w_w_eff = w_w.unsqueeze(1) # (Npix, 1, P)
|
|
478
|
-
S = 1
|
|
479
|
-
elif w_idx.ndim == 3:
|
|
480
|
-
# (Npix, S, P)
|
|
481
|
-
Npix_, S, P_ = w_idx.shape
|
|
482
|
-
assert Npix_ == Npix and P_ == P, \
|
|
483
|
-
f"`w_idx` must be (Npix,S,P) with Npix={Npix}, P={P}, got {tuple(w_idx.shape)}"
|
|
484
|
-
assert w_w.shape == w_idx.shape, "`w_w` must match `w_idx` shape"
|
|
485
|
-
w_idx_eff = w_idx
|
|
486
|
-
w_w_eff = w_w
|
|
487
|
-
else:
|
|
488
|
-
raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")
|
|
489
|
-
|
|
490
|
-
# ---- 3) Normalize ww to (B, C_i, C_o, M, S) for uniform gather ----
|
|
758
|
+
)
|
|
759
|
+
|
|
760
|
+
# ---- 2) Normalize ww to (B, C_i, C_o, M, S)
|
|
491
761
|
if ww.ndim == 3:
|
|
492
|
-
# (C_i, C_o, M) -> (B, C_i, C_o, M, S)
|
|
493
762
|
C_i_w, C_o, M = ww.shape
|
|
494
763
|
assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
|
|
495
764
|
ww_eff = ww.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, -1, S)
|
|
496
|
-
|
|
497
765
|
elif ww.ndim == 4:
|
|
498
|
-
# Could be (C_i, C_o, M, S) or (B, C_i, C_o, M)
|
|
499
766
|
if ww.shape[0] == C_i and ww.shape[1] != C_i:
|
|
500
|
-
# (C_i, C_o, M, S)
|
|
767
|
+
# (C_i, C_o, M, S)
|
|
501
768
|
C_i_w, C_o, M, S_w = ww.shape
|
|
502
769
|
assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
|
|
503
770
|
assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
|
|
504
771
|
ww_eff = ww.unsqueeze(0).expand(B, -1, -1, -1, -1)
|
|
505
772
|
elif ww.shape[0] == B:
|
|
506
|
-
# (B, C_i, C_o, M)
|
|
773
|
+
# (B, C_i, C_o, M)
|
|
507
774
|
_, C_i_w, C_o, M = ww.shape
|
|
508
775
|
assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
|
|
509
776
|
ww_eff = ww.unsqueeze(-1).expand(-1, -1, -1, -1, S)
|
|
510
777
|
else:
|
|
511
|
-
raise ValueError(
|
|
512
|
-
f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)"
|
|
513
|
-
)
|
|
514
|
-
|
|
778
|
+
raise ValueError(f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)")
|
|
515
779
|
elif ww.ndim == 5:
|
|
516
780
|
# (B, C_i, C_o, M, S)
|
|
517
781
|
assert ww.shape[0] == B and ww.shape[1] == C_i, "ww batch/C_i mismatch"
|
|
@@ -520,27 +784,150 @@ class HOrientedConvol:
|
|
|
520
784
|
ww_eff = ww
|
|
521
785
|
else:
|
|
522
786
|
raise ValueError(f"Unsupported ww shape {tuple(ww.shape)}")
|
|
523
|
-
|
|
524
|
-
#
|
|
525
|
-
|
|
787
|
+
|
|
788
|
+
# --- Sanitize shapes: ensure w_idx_eff / w_w_eff == (B, Npix, S, P)
|
|
789
|
+
|
|
790
|
+
# ---- 3) Gather along M using w_idx_eff -> (B, C_i, C_o, Npix, S, P)
|
|
791
|
+
idx_exp = w_idx_eff[:, None, None, :, :, :] # (B,1,1,Npix,S,P)
|
|
526
792
|
rw = torch.take_along_dim(
|
|
527
|
-
ww_eff.unsqueeze(-1),
|
|
528
|
-
idx_exp,
|
|
529
|
-
dim=3
|
|
793
|
+
ww_eff.unsqueeze(-1), # (B,C_i,C_o,M,S,1)
|
|
794
|
+
idx_exp, # (B,1,1,Npix,S,P)
|
|
795
|
+
dim=3 # gather along M
|
|
530
796
|
) # -> (B, C_i, C_o, Npix, S, P)
|
|
797
|
+
# ---- 4) Apply extra neighbor weights ----
|
|
798
|
+
rw = rw * w_w_eff[:, None, None, :, :, :] # (B, C_i, C_o, Npix, S, P)
|
|
799
|
+
# ---- 5) Combine neighbor values and weights ----
|
|
800
|
+
rim_exp = rim[:, :, None, :, None, :] # (B, C_i, 1, Npix, 1, P)
|
|
801
|
+
out_ci = (rim_exp * rw).sum(dim=-1) # sum over P -> (B, C_i, C_o, Npix, S)
|
|
802
|
+
out_ci = out_ci.sum(dim=-1) # sum over S -> (B, C_i, C_o, Npix)
|
|
803
|
+
out = out_ci.sum(dim=1) # sum over C_i -> (B, C_o, Npix)
|
|
804
|
+
|
|
805
|
+
return out
|
|
531
806
|
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
807
|
+
def _to_numpy_1d(self, ids):
|
|
808
|
+
"""Return a 1D numpy array of int64 for a single set of cell ids."""
|
|
809
|
+
import numpy as np, torch
|
|
810
|
+
if isinstance(ids, np.ndarray):
|
|
811
|
+
return ids.reshape(-1).astype(np.int64, copy=False)
|
|
812
|
+
if torch.is_tensor(ids):
|
|
813
|
+
return ids.detach().cpu().to(torch.long).view(-1).numpy()
|
|
814
|
+
# python list/tuple of ints
|
|
815
|
+
return np.asarray(ids, dtype=np.int64).reshape(-1)
|
|
816
|
+
|
|
817
|
+
def _is_varlength_batch(self, ids):
|
|
818
|
+
"""
|
|
819
|
+
True if ids is a list/tuple of per-sample id arrays (var-length batch).
|
|
820
|
+
False if ids is a single array/tensor of ids (shared for whole batch).
|
|
821
|
+
"""
|
|
822
|
+
import numpy as np, torch
|
|
823
|
+
if isinstance(ids, (list, tuple)):
|
|
824
|
+
return True
|
|
825
|
+
if isinstance(ids, np.ndarray) and ids.ndim == 2:
|
|
826
|
+
# This would be a dense (B, Npix) matrix -> NOT var-length list
|
|
827
|
+
return False
|
|
828
|
+
if torch.is_tensor(ids) and ids.dim() == 2:
|
|
829
|
+
return False
|
|
830
|
+
return False
|
|
831
|
+
|
|
832
|
+
def Down(self, im, cell_ids=None, nside=None,max_poll=False):
|
|
833
|
+
"""
|
|
834
|
+
If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
|
|
835
|
+
If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
|
|
836
|
+
"""
|
|
837
|
+
if self.f is None:
|
|
838
|
+
if self.dtype==torch.float64:
|
|
839
|
+
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
|
|
840
|
+
else:
|
|
841
|
+
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
|
|
842
|
+
|
|
843
|
+
if cell_ids is None:
|
|
844
|
+
dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside,max_poll=False)
|
|
845
|
+
return dim,cdim
|
|
846
|
+
|
|
847
|
+
if nside is None:
|
|
848
|
+
nside = self.nside
|
|
849
|
+
|
|
850
|
+
# var-length mode: list/tuple of ids, one per sample
|
|
851
|
+
if self._is_varlength_batch(cell_ids):
|
|
852
|
+
outs, outs_ids = [], []
|
|
853
|
+
B = len(cell_ids)
|
|
854
|
+
for b in range(B):
|
|
855
|
+
cid_b = self._to_numpy_1d(cell_ids[b])
|
|
856
|
+
# extraire le bon échantillon d'`im`
|
|
857
|
+
if torch.is_tensor(im):
|
|
858
|
+
xb = im[b:b+1] # (1, C, N_b)
|
|
859
|
+
yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
|
|
860
|
+
outs.append(yb.squeeze(0)) # (C, N_b')
|
|
861
|
+
else:
|
|
862
|
+
# si im est déjà une liste de (C, N_b)
|
|
863
|
+
xb = im[b]
|
|
864
|
+
yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
|
|
865
|
+
outs.append(yb.squeeze(0))
|
|
866
|
+
outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
|
|
867
|
+
return outs, outs_ids
|
|
868
|
+
|
|
869
|
+
# grille commune (un seul vecteur d'ids)
|
|
870
|
+
cid = self._to_numpy_1d(cell_ids)
|
|
871
|
+
return self.f.ud_grade_2(im, cell_ids=cid, nside=nside,max_poll=False)
|
|
872
|
+
|
|
873
|
+
def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
|
|
874
|
+
"""
|
|
875
|
+
If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
|
|
876
|
+
If they are lists (var-length per sample) -> return list[Tensor].
|
|
877
|
+
"""
|
|
878
|
+
if self.f is None:
|
|
879
|
+
if self.dtype==torch.float64:
|
|
880
|
+
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
|
|
881
|
+
else:
|
|
882
|
+
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
|
|
883
|
+
|
|
884
|
+
if cell_ids is None:
|
|
885
|
+
dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
|
|
886
|
+
return dim
|
|
887
|
+
|
|
888
|
+
if nside is None:
|
|
889
|
+
nside = self.nside
|
|
890
|
+
|
|
891
|
+
# var-length: listes parallèles
|
|
892
|
+
if self._is_varlength_batch(cell_ids):
|
|
893
|
+
assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
|
|
894
|
+
"In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
|
|
895
|
+
outs = []
|
|
896
|
+
B = len(cell_ids)
|
|
897
|
+
for b in range(B):
|
|
898
|
+
cid_b = self._to_numpy_1d(cell_ids[b]) # coarse ids
|
|
899
|
+
ocid_b = self._to_numpy_1d(o_cell_ids[b]) # fine ids
|
|
900
|
+
if torch.is_tensor(im):
|
|
901
|
+
xb = im[b:b+1] # (1, C, N_b_coarse)
|
|
902
|
+
yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
|
|
903
|
+
o_cell_ids=ocid_b, force_init_index=True)
|
|
904
|
+
outs.append(yb.squeeze(0)) # (C, N_b_fine)
|
|
905
|
+
else:
|
|
906
|
+
xb = im[b] # (C, N_b_coarse)
|
|
907
|
+
yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
|
|
908
|
+
o_cell_ids=ocid_b, force_init_index=True)
|
|
909
|
+
outs.append(yb.squeeze(0))
|
|
910
|
+
return outs
|
|
911
|
+
|
|
912
|
+
# grille commune
|
|
913
|
+
cid = self._to_numpy_1d(cell_ids)
|
|
914
|
+
ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
|
|
915
|
+
return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
|
|
916
|
+
o_cell_ids=ocid, force_init_index=True)
|
|
917
|
+
|
|
918
|
+
def to_tensor(self,x):
|
|
919
|
+
if self.f is None:
|
|
920
|
+
if self.dtype==torch.float64:
|
|
921
|
+
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
|
|
922
|
+
else:
|
|
923
|
+
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
|
|
924
|
+
return self.f.backend.bk_cast(x)
|
|
542
925
|
|
|
543
|
-
|
|
926
|
+
def to_numpy(self,x):
|
|
927
|
+
if isinstance(x,np.ndarray):
|
|
928
|
+
return x
|
|
929
|
+
return x.cpu().numpy()
|
|
930
|
+
|
|
544
931
|
|
|
545
932
|
|
|
546
933
|
|