foscat-2025.11.1-py3-none-any.whl → foscat-2026.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/FoCUS.py +71 -16
- foscat/SphereDownGeo.py +380 -0
- foscat/SphereUpGeo.py +175 -0
- foscat/SphericalStencil.py +27 -246
- foscat/alm_loc.py +270 -0
- foscat/healpix_vit_torch-old.py +658 -0
- foscat/scat_cov.py +24 -24
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/METADATA +1 -69
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/RECORD +12 -8
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/WHEEL +1 -1
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/top_level.txt +0 -0
foscat/SphereUpGeo.py
ADDED
@@ -0,0 +1,175 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+from foscat.SphereDownGeo import SphereDownGeo
+
+
+class SphereUpGeo(nn.Module):
+    """Geometric HEALPix upsampling operator using the transpose of SphereDownGeo.
+
+    `cell_ids_out` (coarse pixels at nside_out, NESTED) is mandatory.
+    Forward expects x of shape [B, C, K_out] aligned with that order.
+    Output is a full fine-grid map [B, C, N_in] at nside_in = 2*nside_out.
+
+    Normalization (diagonal corrections):
+      - up_norm='adjoint': x_up = M^T x
+      - up_norm='col_l1':  x_up = (M^T x) / col_sum, col_sum[i] = sum_k M[k,i]
+      - up_norm='diag_l2': x_up = (M^T x) / col_l2,  col_l2[i] = sum_k M[k,i]^2
+    """
+
+    def __init__(
+        self,
+        nside_out: int,
+        cell_ids_out,
+        radius_deg: float | None = None,
+        sigma_deg: float | None = None,
+        weight_norm: str = "l1",
+        up_norm: str = "col_l1",
+        eps: float = 1e-12,
+        device=None,
+        dtype=torch.float32,
+    ):
+        super().__init__()
+
+        if cell_ids_out is None:
+            raise ValueError("cell_ids_out is mandatory (1D list/np/tensor of coarse HEALPix ids at nside_out).")
+
+        if device is None:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
+        self.dtype = dtype
+
+        self.nside_out = int(nside_out)
+        assert (self.nside_out & (self.nside_out - 1)) == 0, "nside_out must be a power of 2."
+        self.nside_in = self.nside_out * 2
+
+        self.N_out = 12 * self.nside_out * self.nside_out
+        self.N_in = 12 * self.nside_in * self.nside_in
+
+        up_norm = str(up_norm).lower().strip()
+        if up_norm not in ("adjoint", "col_l1", "diag_l2"):
+            raise ValueError("up_norm must be 'adjoint', 'col_l1', or 'diag_l2'.")
+        self.up_norm = up_norm
+        self.eps = float(eps)
+
+        # Coarse ids in user-provided order (must be unique for alignment)
+        if isinstance(cell_ids_out, torch.Tensor):
+            cell_ids_out_np = cell_ids_out.detach().cpu().numpy().astype(np.int64)
+        else:
+            cell_ids_out_np = np.asarray(cell_ids_out, dtype=np.int64)
+
+        if cell_ids_out_np.ndim != 1:
+            raise ValueError("cell_ids_out must be 1D")
+        if cell_ids_out_np.size == 0:
+            raise ValueError("cell_ids_out must be non-empty")
+        if cell_ids_out_np.min() < 0 or cell_ids_out_np.max() >= self.N_out:
+            raise ValueError("cell_ids_out contains out-of-bounds ids for this nside_out")
+        if np.unique(cell_ids_out_np).size != cell_ids_out_np.size:
+            raise ValueError("cell_ids_out must not contain duplicates (order matters for alignment).")
+
+        self.cell_ids_out_np = cell_ids_out_np
+        self.K_out = int(cell_ids_out_np.size)
+        self.register_buffer("cell_ids_out_t", torch.as_tensor(cell_ids_out_np, dtype=torch.long, device=self.device))
+
+        # Build the FULL down operator at fine resolution (nside_in -> nside_out)
+        tmp_down = SphereDownGeo(
+            nside_in=self.nside_in,
+            mode="smooth",
+            radius_deg=radius_deg,
+            sigma_deg=sigma_deg,
+            weight_norm=weight_norm,
+            device=self.device,
+            dtype=self.dtype,
+            use_csr=False,
+        )
+
+        M_down_full = torch.sparse_coo_tensor(
+            tmp_down.M.indices(),
+            tmp_down.M.values(),
+            size=(tmp_down.N_out, tmp_down.N_in),
+            device=self.device,
+            dtype=self.dtype,
+        ).coalesce()
+
+        # Extract ONLY the requested coarse rows, in the provided order.
+        # We do this on CPU with numpy for simplicity and speed at init.
+        idx = M_down_full.indices().cpu().numpy()
+        vals = M_down_full.values().cpu().numpy()
+        rows = idx[0]
+        cols = idx[1]
+
+        # Map original row id -> new row position [0..K_out-1]
+        row_map = {int(r): i for i, r in enumerate(cell_ids_out_np.tolist())}
+        mask = np.fromiter((r in row_map for r in rows), dtype=bool, count=rows.size)
+
+        rows_sel = rows[mask]
+        cols_sel = cols[mask]
+        vals_sel = vals[mask]
+
+        new_rows = np.fromiter((row_map[int(r)] for r in rows_sel), dtype=np.int64, count=rows_sel.size)
+
+        M_down_sub = torch.sparse_coo_tensor(
+            torch.as_tensor(np.stack([new_rows, cols_sel], axis=0), dtype=torch.long),
+            torch.as_tensor(vals_sel, dtype=self.dtype),
+            size=(self.K_out, self.N_in),
+            device=self.device,
+            dtype=self.dtype,
+        ).coalesce()
+
+        # Store M^T (sparse) so forward is just sparse.mm
+        M_up = self._transpose_sparse(M_down_sub)  # [N_in, K_out]
+        self.register_buffer("M_indices", M_up.indices())
+        self.register_buffer("M_values", M_up.values())
+        self.M_size = M_up.size()
+
+        # Diagonal normalizers (length N_in), based on the selected coarse rows only
+        idx_sub = M_down_sub.indices()
+        vals_sub = M_down_sub.values()
+        fine_cols = idx_sub[1]
+
+        col_sum = torch.zeros(self.N_in, device=self.device, dtype=self.dtype)
+        col_l2 = torch.zeros(self.N_in, device=self.device, dtype=self.dtype)
+        col_sum.scatter_add_(0, fine_cols, vals_sub)
+        col_l2.scatter_add_(0, fine_cols, vals_sub * vals_sub)
+
+        self.register_buffer("col_sum", col_sum)
+        self.register_buffer("col_l2", col_l2)
+
+        # Fine ids (full sphere)
+        self.register_buffer("cell_ids_in_t", torch.arange(self.N_in, dtype=torch.long, device=self.device))
+
+        self.M_T = torch.sparse_coo_tensor(
+            self.M_indices.to(device=self.device),
+            self.M_values.to(device=self.device, dtype=self.dtype),
+            size=self.M_size,
+            device=self.device,
+            dtype=self.dtype,
+        ).coalesce().to_sparse_csr().to(self.device)
+
+    @staticmethod
+    def _transpose_sparse(M: torch.Tensor) -> torch.Tensor:
+        M = M.coalesce()
+        idx = M.indices()
+        vals = M.values()
+        R, C = M.size()
+        idx_T = torch.stack([idx[1], idx[0]], dim=0)
+        return torch.sparse_coo_tensor(idx_T, vals, size=(C, R), device=M.device, dtype=M.dtype).coalesce()
+
+    def forward(self, x: torch.Tensor):
+        """x: [B, C, K_out] -> x_up: [B, C, N_in]."""
+        B, C, K_out = x.shape
+        assert K_out == self.K_out, f"Expected K_out={self.K_out}, got {K_out}"
+
+        x_bc = x.reshape(B * C, K_out)
+        x_up_bc_T = torch.sparse.mm(self.M_T, x_bc.T)  # [N_in, B*C]
+        x_up = x_up_bc_T.T.reshape(B, C, self.N_in)    # [B, C, N_in]
+
+        if self.up_norm == "col_l1":
+            denom = self.col_sum.to(device=x.device, dtype=x.dtype).clamp_min(self.eps)
+            x_up = x_up / denom.view(1, 1, -1)
+        elif self.up_norm == "diag_l2":
+            denom = self.col_l2.to(device=x.device, dtype=x.dtype).clamp_min(self.eps)
+            x_up = x_up / denom.view(1, 1, -1)
+
+        return x_up, self.cell_ids_in_t.to(device=x.device)
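Usage note: a minimal sketch of driving the new module. The nside, patch ids, and random input below are illustrative, not taken from the package, and SphereDownGeo is assumed to fall back to internal defaults when radius_deg/sigma_deg are left as None.

import numpy as np
import torch
from foscat.SphereUpGeo import SphereUpGeo

# Illustrative patch: the first quarter of the coarse NESTED grid at nside_out=16.
nside_out = 16
cell_ids_out = np.arange(12 * nside_out**2 // 4, dtype=np.int64)

up = SphereUpGeo(nside_out=nside_out, cell_ids_out=cell_ids_out, up_norm="col_l1")

# One value per requested coarse pixel, in the same order as cell_ids_out.
x = torch.randn(2, 3, cell_ids_out.size, device=up.device, dtype=up.dtype)  # [B, C, K_out]
x_up, fine_ids = up(x)
print(x_up.shape)      # torch.Size([2, 3, 12288]): full fine grid at nside_in=32
print(fine_ids.shape)  # torch.Size([12288]): NESTED ids 0..N_in-1

With up_norm='col_l1', dividing by col_sum renormalizes each fine pixel by the total weight it receives, so a constant coarse map upsamples to the same constant on covered fine pixels; fine pixels that receive no weight from the selected coarse rows stay at zero (the eps clamp only guards the division).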
foscat/SphericalStencil.py
CHANGED
@@ -2,13 +2,8 @@
 # Author: J.-M. Delouis
 import numpy as np
 import healpy as hp
-import foscat.scat_cov as sc
 import torch
 
-import numpy as np
-import torch
-import healpy as hp
-
 
 class SphericalStencil:
     """
@@ -61,8 +56,7 @@ class SphericalStencil:
         device=None,
         dtype=None,
         n_gauges=1,
-        gauge_type='
-        scat_op=None,
+        gauge_type='phi',
     ):
         assert kernel_sz >= 1 and int(kernel_sz) == kernel_sz
         assert kernel_sz % 2 == 1, "kernel_sz must be odd"
@@ -75,10 +69,6 @@ class SphericalStencil:
         self.gauge_type=gauge_type
 
         self.nest = bool(nest)
-        if scat_op is None:
-            self.f=sc.funct(KERNELSZ=self.KERNELSZ)
-        else:
-            self.f=scat_op
 
         # Torch defaults
         if device is None:
@@ -354,10 +344,27 @@ class SphericalStencil:
         # --- build the local (P,3) stencil once on device
         P = self.P
         vec_np = np.zeros((P, 3), dtype=float)
-        grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2)
-
-
-
+        grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2)
+
+        # NEW: angular offsets
+        xx,yy=np.meshgrid(grid,grid)
+        s=1.0 # could be modified
+        alpha_pix = hp.nside2resol(self.nside, arcmin=False)  # ~ typical angular pixel size
+        dtheta = (np.sqrt(xx**2+yy**2) * alpha_pix * s).ravel()
+        dphi = (np.arctan2(yy,xx)).ravel()
+        # local spherical displacement
+        # convert to unit vectors
+        x = np.sin(dtheta) * np.cos(dphi)
+        y = np.sin(dtheta) * np.sin(dphi)
+        z = np.cos(dtheta)
+        #print(self.nside*x.reshape(self.KERNELSZ,self.KERNELSZ))
+        #print(self.nside*y.reshape(self.KERNELSZ,self.KERNELSZ))
+        #print(self.nside*z.reshape(self.KERNELSZ,self.KERNELSZ))
+        vec_np = np.stack([x, y, z], axis=-1)
+
+        #vec_np[:, 0] = np.tile(grid, self.KERNELSZ)
+        #vec_np[:, 1] = np.repeat(grid, self.KERNELSZ)
+        #vec_np[:, 2] = 1.0 - np.sqrt(vec_np[:, 0]**2 + vec_np[:, 1]**2)
         vec_t = torch.as_tensor(vec_np, device=self.device, dtype=self.dtype)  # (P,3)
 
         # --- rotation matrices for all targets & gauges: (K,G,3,3)
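Note on the new stencil geometry: the old flat tangent-plane offsets (with an ad hoc z component, still visible in the commented-out lines) are replaced by true angular offsets. Each kernel cell is placed at colatitude dtheta = r * alpha_pix and azimuth dphi around the local pole, which yields exact unit vectors that are then rotated onto each target pixel. A standalone check of that construction, with illustrative KERNELSZ and nside values:

import numpy as np
import healpy as hp

KERNELSZ, nside = 3, 16
grid = np.arange(KERNELSZ) - KERNELSZ // 2
xx, yy = np.meshgrid(grid, grid)
alpha_pix = hp.nside2resol(nside, arcmin=False)        # mean pixel spacing in radians
dtheta = (np.sqrt(xx**2 + yy**2) * alpha_pix).ravel()  # radial distance from the pole
dphi = np.arctan2(yy, xx).ravel()                      # azimuth of each kernel cell
vec = np.stack([np.sin(dtheta) * np.cos(dphi),
                np.sin(dtheta) * np.sin(dphi),
                np.cos(dtheta)], axis=-1)              # (P, 3)

assert np.allclose(np.linalg.norm(vec, axis=1), 1.0)        # all points lie on the sphere
assert np.allclose(vec[KERNELSZ**2 // 2], [0.0, 0.0, 1.0])  # center cell sits at the local pole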
@@ -371,7 +378,7 @@ class SphericalStencil:
             th, ph, alpha, G=self.G, gauge_cosmo=(self.gauge_type=='cosmo'),
             device=self.device, dtype=self.dtype
         )  # shape (K,G,3,3)
-
+
         # --- rotate stencil for each (target, gauge): (K,G,P,3)
         # einsum over local stencil (P,3) with rotation (K,G,3,3)
         rotated = torch.einsum('kgij,pj->kgpi', R_t, vec_t)  # (K,G,P,3)
@@ -568,119 +575,6 @@ class SphericalStencil:
         self.dtype = dtype
 
 
-    '''
-    def bind_support_torch_multi(self, ids_sorted_np, *, device=None, dtype=None):
-        """
-        Multi-gauge sparse binding (Step B).
-        Uses self.idx_t_multi / self.w_t_multi prepared by prepare_torch(..., G>1)
-        and builds, for each gauge g, (pos_safe, w_norm, present).
-
-        Parameters
-        ----------
-        ids_sorted_np : np.ndarray (K,)
-            Sorted pixel ids for available samples (matches the last axis of your data).
-        device, dtype : torch device/dtype for the produced mapping tensors.
-
-        Side effects
-        ------------
-        Sets:
-          - self.ids_sorted_np : (K,)
-          - self.pos_safe_t_multi : (G, 4, K*P) LongTensor
-          - self.w_norm_t_multi : (G, 4, K*P) Tensor
-          - self.present_t_multi : (G, 4, K*P) BoolTensor
-          - (and mirrors device/dtype in self.device/self.dtype)
-        """
-        assert hasattr(self, 'idx_t_multi') and self.idx_t_multi is not None, \
-            "Call prepare_torch(..., G>0) before bind_support_torch_multi(...)"
-        assert hasattr(self, 'w_t_multi') and self.w_t_multi is not None
-
-        if device is None: device = self.device
-        if dtype is None: dtype = self.dtype
-
-        self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64).reshape(-1)
-        ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
-
-        G, _, M = self.idx_t_multi.shape
-        K = self.Kb
-        P = self.P
-        assert M == K*P, "idx_t_multi second axis must have K*P columns"
-
-        pos_list, present_list, wnorm_list = [], [], []
-
-        for g in range(G):
-            idx = self.idx_t_multi[g].to(device=device, dtype=torch.long)  # (4, M)
-            w = self.w_t_multi[g].to(device=device, dtype=dtype)           # (4, M)
-
-            pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
-            in_range = pos < ids_sorted.numel()
-            cmp_vals = torch.full_like(idx, -1)
-            cmp_vals[in_range] = ids_sorted[pos[in_range]]
-            present = (cmp_vals == idx)
-
-            # normalize weights per column after masking
-            w = w * present
-            colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
-            w_norm = w / colsum
-
-            pos_safe = torch.where(present, pos, torch.zeros_like(pos))
-
-            pos_list.append(pos_safe)
-            present_list.append(present)
-            wnorm_list.append(w_norm)
-
-        self.pos_safe_t_multi = torch.stack(pos_list, dim=0)      # (G, 4, M)
-        self.present_t_multi = torch.stack(present_list, dim=0)   # (G, 4, M)
-        self.w_norm_t_multi = torch.stack(wnorm_list, dim=0)      # (G, 4, M)
-
-        # mirror runtime placement
-        self.device = device
-        self.dtype = dtype
-
-    # ------------------------------------------------------------------
-    # Step B: bind support Torch
-    # ------------------------------------------------------------------
-    def bind_support_torch(self, ids_sorted_np, *, device=None, dtype=None):
-        """
-        Map HEALPix neighbor indices (from Step A) to actual data samples
-        sorted by pixel id. Produces pos_safe and normalized weights.
-
-        Parameters
-        ----------
-        ids_sorted_np : np.ndarray (K,)
-            Sorted pixel ids for available data.
-        device, dtype : Torch device/dtype for results.
-        """
-        if device is None:
-            device = self.device
-        if dtype is None:
-            dtype = self.dtype
-
-        self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64)
-        ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
-
-        idx = self.idx_t.to(device=device, dtype=torch.long)
-        w = self.w_t.to(device=device, dtype=dtype)
-
-        M = self.Kb * self.P
-        idx = idx.view(4, M)
-        w = w.view(4, M)
-
-        pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
-        in_range = pos < ids_sorted.shape[0]
-        cmp_vals = torch.full_like(idx, -1)
-        cmp_vals[in_range] = ids_sorted[pos[in_range]]
-        present = (cmp_vals == idx)
-
-        w = w * present
-        colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
-        w_norm = w / colsum
-
-        self.pos_safe_t = torch.where(present, pos, torch.zeros_like(pos))
-        self.w_norm_t = w_norm
-        self.present_t = present
-        self.device = device
-        self.dtype = dtype
-    '''
     # ------------------------------------------------------------------
     # Step C: apply convolution (already Torch in your code)
     # ------------------------------------------------------------------
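Note: the block removed above was already dead code (both bind_support methods were wrapped in a triple-quoted string). For reference, the searchsorted membership test they were built on is a standard pattern and can be sketched in isolation; the ids and queries below are made up:

import torch

ids_sorted = torch.tensor([2, 5, 7, 11])    # pixels for which data exists, sorted
idx = torch.tensor([5, 6, 11, 42])          # neighbor pixels the stencil requests

pos = torch.searchsorted(ids_sorted, idx)   # candidate positions in ids_sorted
in_range = pos < ids_sorted.numel()
cmp_vals = torch.full_like(idx, -1)
cmp_vals[in_range] = ids_sorted[pos[in_range]]
present = cmp_vals == idx                   # True only on exact id matches
pos_safe = torch.where(present, pos, torch.zeros_like(pos))  # 0 is a safe dummy index

print(pos_safe.tolist())  # [1, 0, 3, 0]
print(present.tolist())   # [True, False, True, False]

Absent neighbors are masked out of the weights and each column's weight sum is renormalized (with a small clamp), so partial-sky coverage degrades smoothly instead of indexing out of bounds.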
@@ -1215,7 +1109,7 @@ class SphericalStencil:
         vals = torch.cat(vals_all, dim=0)
 
 
-            indices = torch.stack([cols, rows], dim=0)
+        indices = torch.stack([cols, rows], dim=0)
 
         if return_sparse_tensor:
             M = torch.sparse_coo_tensor(indices, vals, size=shape, device=device, dtype=k_dtype).coalesce()
@@ -1224,123 +1118,10 @@ class SphericalStencil:
         return vals, indices, shape
 
 
-    def _to_numpy_1d(self, ids):
-        """Return a 1D numpy array of int64 for a single set of cell ids."""
-        import numpy as np, torch
-        if isinstance(ids, np.ndarray):
-            return ids.reshape(-1).astype(np.int64, copy=False)
-        if torch.is_tensor(ids):
-            return ids.detach().cpu().to(torch.long).view(-1).numpy()
-        # python list/tuple of ints
-        return np.asarray(ids, dtype=np.int64).reshape(-1)
-
-    def _is_varlength_batch(self, ids):
-        """
-        True if ids is a list/tuple of per-sample id arrays (var-length batch).
-        False if ids is a single array/tensor of ids (shared for whole batch).
-        """
-        import numpy as np, torch
-        if isinstance(ids, (list, tuple)):
-            return True
-        if isinstance(ids, np.ndarray) and ids.ndim == 2:
-            # This would be a dense (B, Npix) matrix -> NOT var-length list
-            return False
-        if torch.is_tensor(ids) and ids.dim() == 2:
-            return False
-        return False
-
-    def Down(self, im, cell_ids=None, nside=None,max_poll=False):
-        """
-        If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
-        If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
-        """
-        if self.f is None:
-            if self.dtype==torch.float64:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
-            else:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
-
-        if cell_ids is None:
-            dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside,max_poll=max_poll)
-            return dim,cdim
-
-        if nside is None:
-            nside = self.nside
-
-        # var-length mode: list/tuple of ids, one per sample
-        if self._is_varlength_batch(cell_ids):
-            outs, outs_ids = [], []
-            B = len(cell_ids)
-            for b in range(B):
-                cid_b = self._to_numpy_1d(cell_ids[b])
-                # extract the matching sample from `im`
-                if torch.is_tensor(im):
-                    xb = im[b:b+1]  # (1, C, N_b)
-                    yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
-                    outs.append(yb.squeeze(0))  # (C, N_b')
-                else:
-                    # if im is already a list of (C, N_b)
-                    xb = im[b]
-                    yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
-                    outs.append(yb.squeeze(0))
-                outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
-            return outs, outs_ids
-
-        # common grid (a single vector of ids)
-        cid = self._to_numpy_1d(cell_ids)
-        return self.f.ud_grade_2(im, cell_ids=cid, nside=nside,max_poll=False)
-
-    def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
-        """
-        If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
-        If they are lists (var-length per sample) -> return list[Tensor].
-        """
-        if self.f is None:
-            if self.dtype==torch.float64:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
-            else:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
-
-        if cell_ids is None:
-            dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
-            return dim
-
-        if nside is None:
-            nside = self.nside
-
-        # var-length: parallel lists
-        if self._is_varlength_batch(cell_ids):
-            assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
-                "In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
-            outs = []
-            B = len(cell_ids)
-            for b in range(B):
-                cid_b = self._to_numpy_1d(cell_ids[b])     # coarse ids
-                ocid_b = self._to_numpy_1d(o_cell_ids[b])  # fine ids
-                if torch.is_tensor(im):
-                    xb = im[b:b+1]  # (1, C, N_b_coarse)
-                    yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
-                                         o_cell_ids=ocid_b, force_init_index=True)
-                    outs.append(yb.squeeze(0))  # (C, N_b_fine)
-                else:
-                    xb = im[b]  # (C, N_b_coarse)
-                    yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
-                                         o_cell_ids=ocid_b, force_init_index=True)
-                    outs.append(yb.squeeze(0))
-            return outs
-
-        # common grid
-        cid = self._to_numpy_1d(cell_ids)
-        ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
-        return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
-                               o_cell_ids=ocid, force_init_index=True)
-
-
     def to_tensor(self,x):
-        return torch.tensor(x,device=
-
+        return torch.tensor(x,device='cuda')
+
     def to_numpy(self,x):
         if isinstance(x,np.ndarray):
             return x
-        return x.cpu().numpy()
-
+        return x.cpu().numpy()