foscat 2025.9.1__py3-none-any.whl → 2025.9.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +160 -93
- foscat/FoCUS.py +80 -267
- foscat/HOrientedConvol.py +233 -250
- foscat/HealBili.py +12 -8
- foscat/Plot.py +1112 -142
- foscat/SphericalStencil.py +1346 -0
- foscat/UNET.py +21 -7
- foscat/healpix_unet_torch.py +656 -171
- foscat/scat_cov.py +2 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/METADATA +1 -1
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/RECORD +14 -13
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/WHEEL +0 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1346 @@
|
|
|
1
|
+
# SPDX-License-Identifier: MIT
|
|
2
|
+
# Author: J.-M. Delouis
|
|
3
|
+
import numpy as np
|
|
4
|
+
import healpy as hp
|
|
5
|
+
import foscat.scat_cov as sc
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
import healpy as hp
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SphericalStencil:
|
|
14
|
+
"""
|
|
15
|
+
GPU-accelerated spherical stencil operator for HEALPix convolutions.
|
|
16
|
+
|
|
17
|
+
This class implements three phases:
|
|
18
|
+
A) Geometry preparation: build local rotated stencil vectors for each target
|
|
19
|
+
pixel, compute HEALPix neighbor indices and interpolation weights.
|
|
20
|
+
B) Sparse binding: map neighbor indices/weights to available data samples
|
|
21
|
+
(sorted ids), and normalize weights.
|
|
22
|
+
C) Convolution: apply multi-channel kernels to sparse gathered data.
|
|
23
|
+
|
|
24
|
+
Once A+B are prepared, multiple convolutions (C) can be applied efficiently
|
|
25
|
+
on the GPU.
|
|
26
|
+
|
|
27
|
+
Parameters
|
|
28
|
+
----------
|
|
29
|
+
nside : int
|
|
30
|
+
HEALPix resolution parameter.
|
|
31
|
+
kernel_sz : int
|
|
32
|
+
Size of local stencil (must be odd, e.g. 3, 5, 7).
|
|
33
|
+
gauge_type : str
|
|
34
|
+
Type of gauge :
|
|
35
|
+
'cosmo' uses the same definition as
|
|
36
|
+
https://www.aanda.org/articles/aa/abs/2022/12/aa44566-22/aa44566-22.html
|
|
37
|
+
'phi' is defined at the pole; it may be better suited to Earth observation that does not use the poles intensively
|
|
38
|
+
n_gauges : int
|
|
39
|
+
Number of oriented gauges (Default 1).
|
|
40
|
+
blend : bool
|
|
41
|
+
Whether to blend smoothly between axisA and axisB (dual gauge).
|
|
42
|
+
power : float
|
|
43
|
+
Sharpness of blend transition (dual gauge).
|
|
44
|
+
nest : bool
|
|
45
|
+
Use nested ordering if True (default), else ring ordering.
|
|
46
|
+
cell_ids : np.ndarray | torch.Tensor | None
|
|
47
|
+
If given, initialize Step A immediately for these targets.
|
|
48
|
+
device : torch.device | str | None
|
|
49
|
+
Default device (if None, 'cuda' if available else 'cpu').
|
|
50
|
+
dtype : torch.dtype | None
|
|
51
|
+
Default dtype (float32 if None).
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
def __init__(
|
|
55
|
+
self,
|
|
56
|
+
nside: int,
|
|
57
|
+
kernel_sz: int,
|
|
58
|
+
*,
|
|
59
|
+
nest: bool = True,
|
|
60
|
+
cell_ids=None,
|
|
61
|
+
device=None,
|
|
62
|
+
dtype=None,
|
|
63
|
+
n_gauges=1,
|
|
64
|
+
gauge_type='cosmo',
|
|
65
|
+
scat_op=None,
|
|
66
|
+
):
|
|
67
|
+
assert kernel_sz >= 1 and int(kernel_sz) == kernel_sz
|
|
68
|
+
assert kernel_sz % 2 == 1, "kernel_sz must be odd"
|
|
69
|
+
|
|
70
|
+
self.nside = int(nside)
|
|
71
|
+
self.KERNELSZ = int(kernel_sz)
|
|
72
|
+
self.P = self.KERNELSZ * self.KERNELSZ
|
|
73
|
+
|
|
74
|
+
self.G = n_gauges
|
|
75
|
+
self.gauge_type=gauge_type
|
|
76
|
+
|
|
77
|
+
self.nest = bool(nest)
|
|
78
|
+
if scat_op is None:
|
|
79
|
+
self.f=sc.funct(KERNELSZ=self.KERNELSZ)
|
|
80
|
+
else:
|
|
81
|
+
self.f=scat_op
|
|
82
|
+
|
|
83
|
+
# Torch defaults
|
|
84
|
+
if device is None:
|
|
85
|
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
86
|
+
if dtype is None:
|
|
87
|
+
dtype = torch.float32
|
|
88
|
+
self.device = torch.device(device)
|
|
89
|
+
self.dtype = dtype
|
|
90
|
+
|
|
91
|
+
# Geometry cache
|
|
92
|
+
self.Kb = None
|
|
93
|
+
self.idx_t = None # (4, K*P) neighbor indices
|
|
94
|
+
self.w_t = None # (4, K*P) interpolation weights
|
|
95
|
+
self.ids_sorted_np = None
|
|
96
|
+
self.pos_safe_t = None
|
|
97
|
+
self.w_norm_t = None
|
|
98
|
+
self.present_t = None
|
|
99
|
+
|
|
100
|
+
# Optionnel : on garde une copie des ids par défaut si fournis
|
|
101
|
+
self.cell_ids_default = None
|
|
102
|
+
|
|
103
|
+
# ---- Optional immediate preparation (Step A+B at init) ----
|
|
104
|
+
if cell_ids is not None:
|
|
105
|
+
# Keep a copy of the default target grid (fast-path later)
|
|
106
|
+
cid = np.asarray(cell_ids, dtype=np.int64).reshape(-1)
|
|
107
|
+
self.cell_ids_default = cid.copy()
|
|
108
|
+
|
|
109
|
+
# Step A (Torch): build geometry for this grid with G gauges
|
|
110
|
+
th, ph = hp.pix2ang(self.nside, cid, nest=self.nest)
|
|
111
|
+
self.prepare_torch(th, ph, G=self.G) # fills idx_t/_multi and w_t/_multi
|
|
112
|
+
|
|
113
|
+
# Step B (Torch): bind sparse mapping on the class device/dtype
|
|
114
|
+
order = np.argsort(cid)
|
|
115
|
+
self.ids_sorted_np = cid[order] # cache for fast-path
|
|
116
|
+
|
|
117
|
+
if self.G > 1:
|
|
118
|
+
# Multi-gauge binding (produces pos_safe_t_multi, w_norm_t_multi)
|
|
119
|
+
self.bind_support_torch_multi(
|
|
120
|
+
self.ids_sorted_np,
|
|
121
|
+
device=self.device,
|
|
122
|
+
dtype=self.dtype,
|
|
123
|
+
)
|
|
124
|
+
else:
|
|
125
|
+
# Single-gauge binding (produces pos_safe_t, w_norm_t)
|
|
126
|
+
self.bind_support_torch(
|
|
127
|
+
self.ids_sorted_np,
|
|
128
|
+
device=self.device,
|
|
129
|
+
dtype=self.dtype,
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# ------------------------------------------------------------------
|
|
134
|
+
# Rotation construction in Torch
|
|
135
|
+
# ------------------------------------------------------------------
|
|
136
|
+
@staticmethod
|
|
137
|
+
def _rotation_total_torch(th, ph, alpha=None, G: int = 1, gauge_cosmo=True,device=None, dtype=None):
|
|
138
|
+
"""
|
|
139
|
+
Build a batch of rotation matrices with *G gauges* per target.
|
|
140
|
+
|
|
141
|
+
Column-vector convention: v' = R @ v.
|
|
142
|
+
|
|
143
|
+
Parameters
|
|
144
|
+
----------
|
|
145
|
+
th : array-like (N,)
|
|
146
|
+
Colatitude.
|
|
147
|
+
ph : array-like (N,)
|
|
148
|
+
Longitude.
|
|
149
|
+
alpha : array-like (N,) or scalar or None
|
|
150
|
+
Base gauge rotation angle around the local normal.
|
|
151
|
+
If None -> 0. For each gauge g in [0..G-1], we add g*pi/G.
|
|
152
|
+
G : int
|
|
153
|
+
Number of gauges to generate per target (>=1).
|
|
154
|
+
device, dtype : torch device/dtype
|
|
155
|
+
|
|
156
|
+
Returns
|
|
157
|
+
-------
|
|
158
|
+
R_tot : torch.Tensor, shape (N, G, 3, 3)
|
|
159
|
+
For each target i and gauge g, the matrix:
|
|
160
|
+
R_tot[i,g] = R_gauge(alpha[i] + g*pi/G) @ Rz(ph[i]) @ Ry(th[i])
|
|
161
|
+
"""
|
|
162
|
+
assert G >= 1, "G must be >= 1"
|
|
163
|
+
|
|
164
|
+
# ---- to torch 1D
|
|
165
|
+
th = torch.as_tensor(th, device=device, dtype=dtype).view(-1)
|
|
166
|
+
ph = torch.as_tensor(ph, device=device, dtype=dtype).view(-1)
|
|
167
|
+
if alpha is None:
|
|
168
|
+
alpha = torch.zeros_like(th)
|
|
169
|
+
else:
|
|
170
|
+
alpha = torch.as_tensor(alpha, device=device, dtype=dtype).view(-1)
|
|
171
|
+
|
|
172
|
+
device = th.device
|
|
173
|
+
dtype = th.dtype
|
|
174
|
+
N = th.shape[0]
|
|
175
|
+
|
|
176
|
+
# ---- base rotation R_base = Rz(ph) @ Ry(th), shape (N,3,3)
|
|
177
|
+
ct, st = torch.cos(th), torch.sin(th)
|
|
178
|
+
cp, sp = torch.cos(ph), torch.sin(ph)
|
|
179
|
+
|
|
180
|
+
R_base = torch.zeros((N, 3, 3), device=device, dtype=dtype)
|
|
181
|
+
# row 0
|
|
182
|
+
R_base[:, 0, 0] = cp * ct
|
|
183
|
+
R_base[:, 0, 1] = -sp
|
|
184
|
+
R_base[:, 0, 2] = cp * st
|
|
185
|
+
# row 1
|
|
186
|
+
R_base[:, 1, 0] = sp * ct
|
|
187
|
+
R_base[:, 1, 1] = cp
|
|
188
|
+
R_base[:, 1, 2] = sp * st
|
|
189
|
+
# row 2
|
|
190
|
+
R_base[:, 2, 0] = -st
|
|
191
|
+
R_base[:, 2, 1] = 0.0
|
|
192
|
+
R_base[:, 2, 2] = ct
|
|
193
|
+
|
|
194
|
+
# local normal n = third column of R_base, shape (N,3)
|
|
195
|
+
n = R_base[:, :, 2]
|
|
196
|
+
n = n / torch.linalg.norm(n, dim=1, keepdim=True).clamp_min(1e-12) # safe normalize
|
|
197
|
+
|
|
198
|
+
# per-target sign: +1 if th <= pi/2 else -1
|
|
199
|
+
sign = torch.where(th <= (np.pi/2), torch.ones_like(th), -torch.ones_like(th)) # (N,)
|
|
200
|
+
|
|
201
|
+
# base gauge shifts (always positive)
|
|
202
|
+
g_shifts = torch.arange(G, device=device, dtype=dtype) * (np.pi / G) # (G,)
|
|
203
|
+
|
|
204
|
+
# broadcast with sign: (N,G)
|
|
205
|
+
if gauge_cosmo:
|
|
206
|
+
alpha_g = alpha[:, None] + sign[:, None] * g_shifts[None, :]
|
|
207
|
+
else:
|
|
208
|
+
alpha_g = alpha[:, None] + g_shifts[None, :]
|
|
209
|
+
|
|
210
|
+
ca = torch.cos(alpha_g) # (N,G)
|
|
211
|
+
sa = torch.sin(alpha_g) # (N,G)
|
|
212
|
+
|
|
213
|
+
# ---- expand normal to (N,G,3)
|
|
214
|
+
n_g = n[:, None, :].expand(N, G, 3) # (N,G,3)
|
|
215
|
+
nx, ny, nz = n_g[..., 0], n_g[..., 1], n_g[..., 2]
|
|
216
|
+
|
|
217
|
+
# skew-symmetric K(n_g), shape (N,G,3,3)
|
|
218
|
+
K = torch.zeros((N, G, 3, 3), device=device, dtype=dtype)
|
|
219
|
+
K[..., 0, 1] = -nz; K[..., 0, 2] = ny
|
|
220
|
+
K[..., 1, 0] = nz; K[..., 1, 2] = -nx
|
|
221
|
+
K[..., 2, 0] = -ny; K[..., 2, 1] = nx
|
|
222
|
+
|
|
223
|
+
# outer(n,n) and identity
|
|
224
|
+
outer = n_g.unsqueeze(-1) * n_g.unsqueeze(-2) # (N,G,3,3)
|
|
225
|
+
I = torch.eye(3, device=device, dtype=dtype).view(1,1,3,3).expand(N, G, 3, 3)
|
|
226
|
+
|
|
227
|
+
# ---- Rodrigues per gauge: R_gauge(N,G,3,3)
|
|
228
|
+
R_gauge = I * ca.view(N, G, 1, 1) + K * sa.view(N, G, 1, 1) + \
|
|
229
|
+
outer * (1.0 - ca).view(N, G, 1, 1)
|
|
230
|
+
|
|
231
|
+
# ---- broadcast multiply with base: R_base_g(N,G,3,3)
|
|
232
|
+
R_base_g = R_base.unsqueeze(1).expand(N, G, 3, 3)
|
|
233
|
+
R_tot = torch.matmul(R_gauge, R_base_g) # (N,G,3,3)
|
|
234
|
+
return R_tot
|
|
235
|
+
|
|
236
|
+
# ------------------------------------------------------------------
|
|
237
|
+
# Torch-based get_interp_weights wrapper
|
|
238
|
+
# ------------------------------------------------------------------
|
|
239
|
+
@staticmethod
|
|
240
|
+
def get_interp_weights_from_vec_torch(
|
|
241
|
+
nside: int,
|
|
242
|
+
vec,
|
|
243
|
+
*,
|
|
244
|
+
nest: bool = True,
|
|
245
|
+
device=None,
|
|
246
|
+
dtype=None,
|
|
247
|
+
chunk_size=1_000_000,
|
|
248
|
+
):
|
|
249
|
+
"""
|
|
250
|
+
Torch wrapper for healpy.get_interp_weights using input vectors.
|
|
251
|
+
|
|
252
|
+
Parameters
|
|
253
|
+
----------
|
|
254
|
+
nside : int
|
|
255
|
+
HEALPix resolution.
|
|
256
|
+
vec : torch.Tensor (...,3)
|
|
257
|
+
Direction vectors (not necessarily normalized).
|
|
258
|
+
nest : bool
|
|
259
|
+
Nested ordering if True (default).
|
|
260
|
+
device, dtype : Torch device/dtype.
|
|
261
|
+
chunk_size : int
|
|
262
|
+
Number of points per healpy call on CPU.
|
|
263
|
+
|
|
264
|
+
Returns
|
|
265
|
+
-------
|
|
266
|
+
idx_t : LongTensor (4, *leading)
|
|
267
|
+
w_t : Tensor (4, *leading)
|
|
268
|
+
"""
|
|
269
|
+
if not isinstance(vec, torch.Tensor):
|
|
270
|
+
vec = torch.as_tensor(vec, device=device, dtype=dtype)
|
|
271
|
+
else:
|
|
272
|
+
device = vec.device if device is None else device
|
|
273
|
+
dtype = vec.dtype if dtype is None else dtype
|
|
274
|
+
vec = vec.to(device=device, dtype=dtype)
|
|
275
|
+
|
|
276
|
+
orig_shape = vec.shape[:-1]
|
|
277
|
+
M = int(np.prod(orig_shape)) if len(orig_shape) else 1
|
|
278
|
+
v = vec.reshape(M, 3)
|
|
279
|
+
|
|
280
|
+
eps = torch.finfo(vec.dtype).eps
|
|
281
|
+
r = torch.linalg.norm(v, dim=1, keepdim=True).clamp_min(eps)
|
|
282
|
+
v_unit = v / r
|
|
283
|
+
x, y, z = v_unit[:, 0], v_unit[:, 1], v_unit[:, 2]
|
|
284
|
+
|
|
285
|
+
theta = torch.acos(z.clamp(-1.0, 1.0))
|
|
286
|
+
phi = torch.atan2(y, x)
|
|
287
|
+
two_pi = torch.tensor(2*np.pi, device=device, dtype=dtype)
|
|
288
|
+
phi = (phi % two_pi)
|
|
289
|
+
|
|
290
|
+
theta_np = theta.detach().cpu().numpy()
|
|
291
|
+
phi_np = phi.detach().cpu().numpy()
|
|
292
|
+
|
|
293
|
+
idx_accum, w_accum = [], []
|
|
294
|
+
for start in range(0, M, chunk_size):
|
|
295
|
+
stop = min(start + chunk_size, M)
|
|
296
|
+
t_chunk, p_chunk = theta_np[start:stop], phi_np[start:stop]
|
|
297
|
+
idx_np, w_np = hp.get_interp_weights(nside, t_chunk, p_chunk, nest=nest)
|
|
298
|
+
idx_accum.append(idx_np)
|
|
299
|
+
w_accum.append(w_np)
|
|
300
|
+
|
|
301
|
+
idx_np_all = np.concatenate(idx_accum, axis=1) if len(idx_accum) > 1 else idx_accum[0]
|
|
302
|
+
w_np_all = np.concatenate(w_accum, axis=1) if len(w_accum) > 1 else w_accum[0]
|
|
303
|
+
|
|
304
|
+
idx_t = torch.as_tensor(idx_np_all, device=device, dtype=torch.long)
|
|
305
|
+
w_t = torch.as_tensor(w_np_all, device=device, dtype=dtype)
|
|
306
|
+
|
|
307
|
+
if len(orig_shape):
|
|
308
|
+
idx_t = idx_t.view(4, *orig_shape)
|
|
309
|
+
w_t = w_t.view(4, *orig_shape)
|
|
310
|
+
|
|
311
|
+
return idx_t, w_t
|
|
312
|
+
|
|
313
|
+
# ------------------------------------------------------------------
|
|
314
|
+
# Step A: geometry preparation fully in Torch
|
|
315
|
+
# ------------------------------------------------------------------
|
|
316
|
+
def prepare_torch(self, th, ph, alpha=None, G: int = 1):
    """
    Step A: build the rotated stencil and its HEALPix neighbors/weights
    for every target and each of the *G gauges*.

    Parameters
    ----------
    th, ph : array-like, shape (K,)
        Target colatitudes/longitudes.
    alpha : array-like (K,) or scalar or None
        Base gauge angle about the local normal at each target. If None,
        it is ``2*((th > pi/2) - 0.5) * ph`` for the 'cosmo' gauge type and
        0 otherwise. Per-gauge shifts are added by
        ``_rotation_total_torch``.
    G : int (>=1)
        Number of gauges to generate per target.

    Side effects
    ------------
    Sets:
      - self.Kb = K
      - self.G  = G
      - self.idx_t_multi : (G, 4, K*P) LongTensor (neighbors per gauge)
      - self.w_t_multi   : (G, 4, K*P) Tensor     (weights per gauge)
      - self.idx_t / self.w_t : gauge-0 slices, each (4, K*P), when G == 1;
        None when G > 1.

    Returns
    -------
    idx_t_multi : torch.LongTensor, shape (G, 4, K*P)
    w_t_multi   : torch.Tensor,     shape (G, 4, K*P)
    """
    # --- angles are sanitized with NumPy; tensors use the class device/dtype
    th = np.asarray(th, float).reshape(-1)
    ph = np.asarray(ph, float).reshape(-1)
    n_targets = th.size
    self.Kb = n_targets
    self.G = int(G)
    assert self.G >= 1, "G must be >= 1"

    n_gauges = self.G
    n_taps = self.P
    ksz = self.KERNELSZ

    # --- local (P,3) stencil: a KERNELSZ x KERNELSZ grid with spacing 1/nside,
    # x varying fastest, and z = 1 - sqrt(x^2 + y^2)
    offsets = (np.arange(ksz) - ksz // 2) / self.nside
    y_grid, x_grid = np.meshgrid(offsets, offsets, indexing='ij')
    x_flat = x_grid.reshape(-1)
    y_flat = y_grid.reshape(-1)
    z_flat = 1.0 - np.sqrt(x_flat**2 + y_flat**2)
    stencil_np = np.stack([x_flat, y_flat, z_flat], axis=1)
    stencil_t = torch.as_tensor(stencil_np, device=self.device, dtype=self.dtype)  # (P,3)

    # --- default base gauge angle depends on the gauge type
    is_cosmo = self.gauge_type == 'cosmo'
    if alpha is None:
        alpha = 2*((th>np.pi/2)-0.5)*ph if is_cosmo else 0.0*th

    # --- rotation matrices for all targets & gauges: (K,G,3,3)
    rot_t = self._rotation_total_torch(
        th,
        ph,
        alpha,
        G=n_gauges,
        gauge_cosmo=is_cosmo,
        device=self.device,
        dtype=self.dtype,
    )

    # --- rotate the stencil for each (target, gauge): (K,G,P,3)
    rotated = torch.einsum('kgij,pj->kgpi', rot_t, stencil_t)

    # --- one HEALPix query over all K*G*P directions; each result is (4, K*G*P)
    nb_idx, nb_w = self.get_interp_weights_from_vec_torch(
        self.nside,
        rotated.reshape(-1, 3),
        nest=self.nest,
        device=self.device,
        dtype=self.dtype,
    )

    def _split_gauges(flat):
        # (4, K*G*P) -> (4, K, G, P) -> (G, 4, K, P) -> (G, 4, K*P)
        return (
            flat.view(4, n_targets, n_gauges, n_taps)
            .permute(2, 0, 1, 3)
            .reshape(n_gauges, 4, n_targets * n_taps)
        )

    # --- cache multi-gauge versions
    self.idx_t_multi = _split_gauges(nb_idx)
    self.w_t_multi = _split_gauges(nb_w)

    # --- backward compatibility: single-gauge fields only exist for G == 1
    if n_gauges == 1:
        self.idx_t = self.idx_t_multi[0]  # (4, K*P)
        self.w_t = self.w_t_multi[0]      # (4, K*P)
    else:
        self.idx_t = None
        self.w_t = None

    return self.idx_t_multi, self.w_t_multi
|
|
409
|
+
|
|
410
|
+
def bind_support_torch_multi(self, ids_sorted_np, *, device=None, dtype=None):
|
|
411
|
+
"""
|
|
412
|
+
Multi-gauge sparse binding (Step B) AVEC logique 'domaine réduit':
|
|
413
|
+
- poids des voisins hors domaine mis à 0
|
|
414
|
+
- renormalisation par colonne à 1
|
|
415
|
+
- si colonne vide: fallback sur le pixel cible (centre du stencil)
|
|
416
|
+
|
|
417
|
+
Produit:
|
|
418
|
+
self.pos_safe_t_multi : (G, 4, K*P)
|
|
419
|
+
self.w_norm_t_multi : (G, 4, K*P)
|
|
420
|
+
self.present_t_multi : (G, 4, K*P)
|
|
421
|
+
"""
|
|
422
|
+
assert hasattr(self, 'idx_t_multi') and self.idx_t_multi is not None, \
|
|
423
|
+
"Call prepare_torch(..., G>0) before bind_support_torch_multi(...)"
|
|
424
|
+
assert hasattr(self, 'w_t_multi') and self.w_t_multi is not None
|
|
425
|
+
|
|
426
|
+
if device is None: device = self.device
|
|
427
|
+
if dtype is None: dtype = self.dtype
|
|
428
|
+
|
|
429
|
+
self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64).reshape(-1)
|
|
430
|
+
ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
|
|
431
|
+
|
|
432
|
+
G, _, M = self.idx_t_multi.shape
|
|
433
|
+
K = self.Kb
|
|
434
|
+
P = self.P
|
|
435
|
+
assert M == K*P, "idx_t_multi second axis must have K*P columns"
|
|
436
|
+
|
|
437
|
+
# index du centre du stencil (en flatten P)
|
|
438
|
+
p_ref = (self.KERNELSZ // 2) * (self.KERNELSZ + 1) # ex. 5 -> 12
|
|
439
|
+
|
|
440
|
+
pos_list, present_list, wnorm_list = [], [], []
|
|
441
|
+
|
|
442
|
+
for g in range(G):
|
|
443
|
+
idx = self.idx_t_multi[g].to(device=device, dtype=torch.long) # (4, M)
|
|
444
|
+
w = self.w_t_multi[g].to(device=device, dtype=dtype) # (4, M)
|
|
445
|
+
|
|
446
|
+
# positions dans ids_sorted
|
|
447
|
+
pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
|
|
448
|
+
in_range = pos < ids_sorted.numel()
|
|
449
|
+
cmp_vals = torch.full_like(idx, -1)
|
|
450
|
+
cmp_vals[in_range] = ids_sorted[pos[in_range]]
|
|
451
|
+
present = (cmp_vals == idx) # (4, M) bool
|
|
452
|
+
|
|
453
|
+
# Colonnes sans AUCUN voisin présent
|
|
454
|
+
empty_cols = ~present.any(dim=0) # (M,)
|
|
455
|
+
if empty_cols.any():
|
|
456
|
+
p_ref = (self.KERNELSZ // 2) * (self.KERNELSZ + 1)
|
|
457
|
+
k_id = torch.div(torch.arange(M, device=device), P, rounding_mode='floor') # (M,)
|
|
458
|
+
ref_cols = k_id * P + p_ref
|
|
459
|
+
src = ref_cols[empty_cols]
|
|
460
|
+
|
|
461
|
+
# copie idx/w de la colonne 'centre'
|
|
462
|
+
idx[:, empty_cols] = idx[:, src]
|
|
463
|
+
w[:, empty_cols] = w[:, src]
|
|
464
|
+
|
|
465
|
+
# --- Recompute presence/pos safely on those columns
|
|
466
|
+
idx_e = idx[:, empty_cols].reshape(-1) # (4*M_empty,)
|
|
467
|
+
pos_e = torch.searchsorted(ids_sorted, idx_e) # (4*M_empty,)
|
|
468
|
+
valid_e = pos_e < ids_sorted.numel()
|
|
469
|
+
pos_e_clipped = pos_e.clamp_max(max(ids_sorted.numel()-1, 0)).to(torch.long)
|
|
470
|
+
cmp_e = ids_sorted[pos_e_clipped]
|
|
471
|
+
present_e = valid_e & (cmp_e == idx_e) # (4*M_empty,)
|
|
472
|
+
|
|
473
|
+
present[:, empty_cols] = present_e.view(4, -1)
|
|
474
|
+
pos[:, empty_cols] = pos_e_clipped.view(4, -1)
|
|
475
|
+
|
|
476
|
+
# Met à zéro les poids absents puis renormalise à 1 par colonne
|
|
477
|
+
w = w * present
|
|
478
|
+
colsum = w.sum(dim=0, keepdim=True)
|
|
479
|
+
zero_cols = (colsum == 0)
|
|
480
|
+
if zero_cols.any():
|
|
481
|
+
w[0, zero_cols[0]] = present[0, zero_cols[0]].to(w.dtype)
|
|
482
|
+
colsum = w.sum(dim=0, keepdim=True)
|
|
483
|
+
w_norm = w / colsum.clamp_min(1e-12)
|
|
484
|
+
|
|
485
|
+
pos_safe = torch.where(present, pos, torch.zeros_like(pos))
|
|
486
|
+
|
|
487
|
+
pos_list.append(pos_safe)
|
|
488
|
+
present_list.append(present)
|
|
489
|
+
wnorm_list.append(w_norm)
|
|
490
|
+
|
|
491
|
+
self.pos_safe_t_multi = torch.stack(pos_list, dim=0) # (G, 4, M)
|
|
492
|
+
self.present_t_multi = torch.stack(present_list, dim=0) # (G, 4, M)
|
|
493
|
+
self.w_norm_t_multi = torch.stack(wnorm_list, dim=0) # (G, 4, M)
|
|
494
|
+
|
|
495
|
+
# miroir device/dtype runtime
|
|
496
|
+
self.device = device
|
|
497
|
+
self.dtype = dtype
|
|
498
|
+
|
|
499
|
+
def bind_support_torch(self, ids_sorted_np, *, device=None, dtype=None):
|
|
500
|
+
"""
|
|
501
|
+
Single-gauge sparse binding (Step B) AVEC logique 'domaine réduit':
|
|
502
|
+
- poids des voisins hors domaine mis à 0
|
|
503
|
+
- renormalisation par colonne à 1
|
|
504
|
+
- si colonne vide: fallback sur le pixel cible (centre du stencil)
|
|
505
|
+
"""
|
|
506
|
+
if device is None:
|
|
507
|
+
device = self.device
|
|
508
|
+
if dtype is None:
|
|
509
|
+
dtype = self.dtype
|
|
510
|
+
|
|
511
|
+
self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64)
|
|
512
|
+
ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
|
|
513
|
+
|
|
514
|
+
idx = self.idx_t.to(device=device, dtype=torch.long) # (4, K*P)
|
|
515
|
+
w = self.w_t.to(device=device, dtype=dtype) # (4, K*P)
|
|
516
|
+
|
|
517
|
+
K = self.Kb
|
|
518
|
+
P = self.P
|
|
519
|
+
M = K * P
|
|
520
|
+
|
|
521
|
+
# positions dans ids_sorted
|
|
522
|
+
pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
|
|
523
|
+
in_range = pos < ids_sorted.shape[0]
|
|
524
|
+
cmp_vals = torch.full_like(idx, -1)
|
|
525
|
+
cmp_vals[in_range] = ids_sorted[pos[in_range]]
|
|
526
|
+
present = (cmp_vals == idx) # (4, M)
|
|
527
|
+
|
|
528
|
+
# Fallback colonnes vides -> centre du stencil
|
|
529
|
+
p_ref = (self.KERNELSZ // 2) * (self.KERNELSZ + 1)
|
|
530
|
+
empty_cols = ~present.any(dim=0) # (M,)
|
|
531
|
+
if empty_cols.any():
|
|
532
|
+
k_id = torch.div(torch.arange(M, device=device), P, rounding_mode='floor') # (M,)
|
|
533
|
+
ref_cols = k_id * P + p_ref
|
|
534
|
+
src = ref_cols[empty_cols]
|
|
535
|
+
|
|
536
|
+
# copie idx/w de la colonne 'centre'
|
|
537
|
+
idx[:, empty_cols] = idx[:, src]
|
|
538
|
+
w[:, empty_cols] = w[:, src]
|
|
539
|
+
|
|
540
|
+
# --- Recompute presence/pos safely on those columns
|
|
541
|
+
idx_e = idx[:, empty_cols].reshape(-1) # (4*M_empty,)
|
|
542
|
+
pos_e = torch.searchsorted(ids_sorted, idx_e) # (4*M_empty,)
|
|
543
|
+
# valid positions strictly inside [0, len)
|
|
544
|
+
valid_e = pos_e < ids_sorted.numel()
|
|
545
|
+
pos_e_clipped = pos_e.clamp_max(max(ids_sorted.numel()-1, 0)).to(torch.long)
|
|
546
|
+
cmp_e = ids_sorted[pos_e_clipped]
|
|
547
|
+
present_e = valid_e & (cmp_e == idx_e) # (4*M_empty,)
|
|
548
|
+
|
|
549
|
+
# reshape back
|
|
550
|
+
present[:, empty_cols] = present_e.view(4, -1)
|
|
551
|
+
pos[:, empty_cols] = pos_e_clipped.view(4, -1)
|
|
552
|
+
|
|
553
|
+
# Zéro poids absents + renormalisation à 1
|
|
554
|
+
w = w * present
|
|
555
|
+
colsum = w.sum(dim=0, keepdim=True)
|
|
556
|
+
zero_cols = (colsum == 0)
|
|
557
|
+
if zero_cols.any():
|
|
558
|
+
# force 1 sur la première ligne disponible (ici ligne 0)
|
|
559
|
+
w[0, zero_cols[0]] = present[0, zero_cols[0]].to(w.dtype)
|
|
560
|
+
colsum = w.sum(dim=0, keepdim=True)
|
|
561
|
+
w_norm = w / colsum.clamp_min(1e-12)
|
|
562
|
+
|
|
563
|
+
self.pos_safe_t = torch.where(present, pos, torch.zeros_like(pos))
|
|
564
|
+
self.w_norm_t = w_norm
|
|
565
|
+
self.present_t = present
|
|
566
|
+
|
|
567
|
+
self.device = device
|
|
568
|
+
self.dtype = dtype
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
'''
|
|
572
|
+
def bind_support_torch_multi(self, ids_sorted_np, *, device=None, dtype=None):
|
|
573
|
+
"""
|
|
574
|
+
Multi-gauge sparse binding (Step B).
|
|
575
|
+
Uses self.idx_t_multi / self.w_t_multi prepared by prepare_torch(..., G>1)
|
|
576
|
+
and builds, for each gauge g, (pos_safe, w_norm, present).
|
|
577
|
+
|
|
578
|
+
Parameters
|
|
579
|
+
----------
|
|
580
|
+
ids_sorted_np : np.ndarray (K,)
|
|
581
|
+
Sorted pixel ids for available samples (matches the last axis of your data).
|
|
582
|
+
device, dtype : torch device/dtype for the produced mapping tensors.
|
|
583
|
+
|
|
584
|
+
Side effects
|
|
585
|
+
------------
|
|
586
|
+
Sets:
|
|
587
|
+
- self.ids_sorted_np : (K,)
|
|
588
|
+
- self.pos_safe_t_multi : (G, 4, K*P) LongTensor
|
|
589
|
+
- self.w_norm_t_multi : (G, 4, K*P) Tensor
|
|
590
|
+
- self.present_t_multi : (G, 4, K*P) BoolTensor
|
|
591
|
+
- (and mirrors device/dtype in self.device/self.dtype)
|
|
592
|
+
"""
|
|
593
|
+
assert hasattr(self, 'idx_t_multi') and self.idx_t_multi is not None, \
|
|
594
|
+
"Call prepare_torch(..., G>0) before bind_support_torch_multi(...)"
|
|
595
|
+
assert hasattr(self, 'w_t_multi') and self.w_t_multi is not None
|
|
596
|
+
|
|
597
|
+
if device is None: device = self.device
|
|
598
|
+
if dtype is None: dtype = self.dtype
|
|
599
|
+
|
|
600
|
+
self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64).reshape(-1)
|
|
601
|
+
ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
|
|
602
|
+
|
|
603
|
+
G, _, M = self.idx_t_multi.shape
|
|
604
|
+
K = self.Kb
|
|
605
|
+
P = self.P
|
|
606
|
+
assert M == K*P, "idx_t_multi second axis must have K*P columns"
|
|
607
|
+
|
|
608
|
+
pos_list, present_list, wnorm_list = [], [], []
|
|
609
|
+
|
|
610
|
+
for g in range(G):
|
|
611
|
+
idx = self.idx_t_multi[g].to(device=device, dtype=torch.long) # (4, M)
|
|
612
|
+
w = self.w_t_multi[g].to(device=device, dtype=dtype) # (4, M)
|
|
613
|
+
|
|
614
|
+
pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
|
|
615
|
+
in_range = pos < ids_sorted.numel()
|
|
616
|
+
cmp_vals = torch.full_like(idx, -1)
|
|
617
|
+
cmp_vals[in_range] = ids_sorted[pos[in_range]]
|
|
618
|
+
present = (cmp_vals == idx)
|
|
619
|
+
|
|
620
|
+
# normalize weights per column after masking
|
|
621
|
+
w = w * present
|
|
622
|
+
colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
|
|
623
|
+
w_norm = w / colsum
|
|
624
|
+
|
|
625
|
+
pos_safe = torch.where(present, pos, torch.zeros_like(pos))
|
|
626
|
+
|
|
627
|
+
pos_list.append(pos_safe)
|
|
628
|
+
present_list.append(present)
|
|
629
|
+
wnorm_list.append(w_norm)
|
|
630
|
+
|
|
631
|
+
self.pos_safe_t_multi = torch.stack(pos_list, dim=0) # (G, 4, M)
|
|
632
|
+
self.present_t_multi = torch.stack(present_list, dim=0) # (G, 4, M)
|
|
633
|
+
self.w_norm_t_multi = torch.stack(wnorm_list, dim=0) # (G, 4, M)
|
|
634
|
+
|
|
635
|
+
# mirror runtime placement
|
|
636
|
+
self.device = device
|
|
637
|
+
self.dtype = dtype
|
|
638
|
+
|
|
639
|
+
# ------------------------------------------------------------------
|
|
640
|
+
# Step B: bind support Torch
|
|
641
|
+
# ------------------------------------------------------------------
|
|
642
|
+
def bind_support_torch(self, ids_sorted_np, *, device=None, dtype=None):
|
|
643
|
+
"""
|
|
644
|
+
Map HEALPix neighbor indices (from Step A) to actual data samples
|
|
645
|
+
sorted by pixel id. Produces pos_safe and normalized weights.
|
|
646
|
+
|
|
647
|
+
Parameters
|
|
648
|
+
----------
|
|
649
|
+
ids_sorted_np : np.ndarray (K,)
|
|
650
|
+
Sorted pixel ids for available data.
|
|
651
|
+
device, dtype : Torch device/dtype for results.
|
|
652
|
+
"""
|
|
653
|
+
if device is None:
|
|
654
|
+
device = self.device
|
|
655
|
+
if dtype is None:
|
|
656
|
+
dtype = self.dtype
|
|
657
|
+
|
|
658
|
+
self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64)
|
|
659
|
+
ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
|
|
660
|
+
|
|
661
|
+
idx = self.idx_t.to(device=device, dtype=torch.long)
|
|
662
|
+
w = self.w_t.to(device=device, dtype=dtype)
|
|
663
|
+
|
|
664
|
+
M = self.Kb * self.P
|
|
665
|
+
idx = idx.view(4, M)
|
|
666
|
+
w = w.view(4, M)
|
|
667
|
+
|
|
668
|
+
pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
|
|
669
|
+
in_range = pos < ids_sorted.shape[0]
|
|
670
|
+
cmp_vals = torch.full_like(idx, -1)
|
|
671
|
+
cmp_vals[in_range] = ids_sorted[pos[in_range]]
|
|
672
|
+
present = (cmp_vals == idx)
|
|
673
|
+
|
|
674
|
+
w = w * present
|
|
675
|
+
colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
|
|
676
|
+
w_norm = w / colsum
|
|
677
|
+
|
|
678
|
+
self.pos_safe_t = torch.where(present, pos, torch.zeros_like(pos))
|
|
679
|
+
self.w_norm_t = w_norm
|
|
680
|
+
self.present_t = present
|
|
681
|
+
self.device = device
|
|
682
|
+
self.dtype = dtype
|
|
683
|
+
'''
|
|
684
|
+
# ------------------------------------------------------------------
|
|
685
|
+
# Step C: apply convolution (already Torch in your code)
|
|
686
|
+
# ------------------------------------------------------------------
|
|
687
|
+
def apply_multi(self, data_sorted_t: torch.Tensor, kernel_t: torch.Tensor):
|
|
688
|
+
"""
|
|
689
|
+
Apply multi-gauge convolution.
|
|
690
|
+
|
|
691
|
+
Inputs
|
|
692
|
+
------
|
|
693
|
+
data_sorted_t : (B, Ci, K) torch.Tensor on self.device/self.dtype
|
|
694
|
+
kernel_t : either
|
|
695
|
+
- (Ci, Co_g, P) : shared kernel for all gauges
|
|
696
|
+
- (G, Ci, Co_g, P) : per-gauge kernels
|
|
697
|
+
|
|
698
|
+
Returns
|
|
699
|
+
-------
|
|
700
|
+
out : (B, G*Co_g, K) torch.Tensor
|
|
701
|
+
"""
|
|
702
|
+
assert hasattr(self, 'pos_safe_t_multi') and self.pos_safe_t_multi is not None, \
|
|
703
|
+
"Call bind_support_torch_multi(...) before apply_multi(...)"
|
|
704
|
+
B, Ci, K = data_sorted_t.shape
|
|
705
|
+
G, _, M = self.pos_safe_t_multi.shape
|
|
706
|
+
assert M == K * self.P
|
|
707
|
+
|
|
708
|
+
# normalize kernel to per-gauge
|
|
709
|
+
if kernel_t.dim() == 3:
|
|
710
|
+
Ci_k, Co_g, P = kernel_t.shape
|
|
711
|
+
assert Ci_k == Ci and P == self.P
|
|
712
|
+
kernel_g = kernel_t[None, ...].expand(G, -1, -1, -1) # (G, Ci, Co_g, P)
|
|
713
|
+
elif kernel_t.dim() == 4:
|
|
714
|
+
Gk, Ci_k, Co_g, P = kernel_t.shape
|
|
715
|
+
assert Gk == G and Ci_k == Ci and P == self.P
|
|
716
|
+
kernel_g = kernel_t
|
|
717
|
+
else:
|
|
718
|
+
raise ValueError("kernel_t must be (Ci,Co_g,P) or (G,Ci,Co_g,P)")
|
|
719
|
+
|
|
720
|
+
outs = []
|
|
721
|
+
for g in range(G):
|
|
722
|
+
pos_safe = self.pos_safe_t_multi[g] # (4, K*P)
|
|
723
|
+
w_norm = self.w_norm_t_multi[g] # (4, K*P)
|
|
724
|
+
|
|
725
|
+
# gather four neighbors then weight -> (B,Ci,K,P)
|
|
726
|
+
vals_g = []
|
|
727
|
+
for j in range(4):
|
|
728
|
+
vj = data_sorted_t.index_select(2, pos_safe[j].reshape(-1)) # (B,Ci,K*P)
|
|
729
|
+
vj = vj.view(B, Ci, K, self.P)
|
|
730
|
+
vals_g.append(vj * w_norm[j].view(1, 1, K, self.P))
|
|
731
|
+
tmp = sum(vals_g) # (B,Ci,K,P)
|
|
732
|
+
|
|
733
|
+
# spatial+channel mixing with kernel of this gauge -> (B,Co_g,K)
|
|
734
|
+
yg = torch.einsum('bckp,cop->bok', tmp, kernel_g[g])
|
|
735
|
+
outs.append(yg)
|
|
736
|
+
|
|
737
|
+
# concat the gauges along channel dimension: (B, G*Co_g, K)
|
|
738
|
+
return torch.cat(outs, dim=1)
|
|
739
|
+
|
|
740
|
+
def apply(self, data_sorted_t, kernel_t):
    """
    Single-gauge convolution of batched sparse data with a (Ci, Co, P) kernel.

    For each of the four rows ``j`` of ``self.pos_safe_t`` the data are
    gathered along their last axis, multiplied by the matching row of
    ``self.w_norm_t`` and the four weighted gathers are summed. The result
    is then contracted with the kernel over input channels and stencil taps.

    Parameters
    ----------
    data_sorted_t : torch.Tensor, shape (B, Ci, K)
        Input values; the last axis is indexed by ``self.pos_safe_t``.
    kernel_t : torch.Tensor, shape (Ci, Co, P)
        Convolution kernel, with ``P == self.P``.

    Returns
    -------
    torch.Tensor, shape (B, Co, K)
    """
    assert self.pos_safe_t is not None and self.w_norm_t is not None
    n_batch, n_in, n_pix = data_sorted_t.shape
    k_in, _n_out, n_taps = kernel_t.shape
    assert k_in == n_in and n_taps == self.P

    def _weighted_gather(j):
        # Row j holds, for every (pixel, tap) pair, the position of the j-th
        # neighbour in the data's last axis and its interpolation weight.
        flat_pos = self.pos_safe_t[j].reshape(-1)
        picked = data_sorted_t.index_select(2, flat_pos)
        picked = picked.view(n_batch, n_in, n_pix, n_taps)
        return picked * self.w_norm_t[j].view(1, 1, n_pix, n_taps)

    interpolated = sum(_weighted_gather(j) for j in range(4))  # (B, Ci, K, P)

    # Mix taps and input channels into output channels.
    return torch.einsum('bckp,cop->bok', interpolated, kernel_t)
|
|
770
|
+
|
|
771
|
+
def _Convol_Torch(self, data: torch.Tensor, kernel: torch.Tensor, cell_ids=None) -> torch.Tensor:
|
|
772
|
+
"""
|
|
773
|
+
Convenience entry point with automatic single- or multi-gauge dispatch.
|
|
774
|
+
|
|
775
|
+
Behavior
|
|
776
|
+
--------
|
|
777
|
+
- If `cell_ids is None`: use cached geometry (prepare_torch) and sparse mapping
|
|
778
|
+
(bind_support_torch or bind_support_torch_multi) already stored in the class,
|
|
779
|
+
re-binding Step-B to `data`'s device/dtype when needed, then apply.
|
|
780
|
+
- If `cell_ids` is provided: compute geometry + sparse mapping for these cells
|
|
781
|
+
using the class' gauge setup (including the number of gauges G prepared by
|
|
782
|
+
`prepare_torch(..., G)`), reorder `data` to match the sorted ids, apply
|
|
783
|
+
(single or multi), and finally unsort to the original `cell_ids` order.
|
|
784
|
+
|
|
785
|
+
Parameters
|
|
786
|
+
----------
|
|
787
|
+
data : (B, Ci, K) torch.float
|
|
788
|
+
Sparse map values. Last axis K must equal the number of target pixels.
|
|
789
|
+
kernel : torch.Tensor
|
|
790
|
+
- Single-gauge path: (Ci, Co, P) where P = kernel_sz**2.
|
|
791
|
+
- Multi-gauge path: (Ci, Co_g, P) shared kernel for all gauges, OR
|
|
792
|
+
(G, Ci, Co_g, P) per-gauge kernels.
|
|
793
|
+
The output channels will be Co (single) or G*Co_g (multi).
|
|
794
|
+
cell_ids : Optional[np.ndarray | torch.Tensor], shape (K,)
|
|
795
|
+
Target HEALPix pixels. If None, re-use the class' cached targets.
|
|
796
|
+
|
|
797
|
+
Returns
|
|
798
|
+
-------
|
|
799
|
+
out : torch.Tensor, shape (B, Co, K)
|
|
800
|
+
Co = Co (single gauge) or Co = G*Co_g (multi-gauge).
|
|
801
|
+
"""
|
|
802
|
+
assert isinstance(data, torch.Tensor) and isinstance(kernel, torch.Tensor), \
|
|
803
|
+
"data and kernel must be torch.Tensors"
|
|
804
|
+
device = data.device
|
|
805
|
+
dtype = data.dtype
|
|
806
|
+
|
|
807
|
+
B, Ci, K_data = data.shape
|
|
808
|
+
P = self.P
|
|
809
|
+
P_k = kernel.shape[-1]
|
|
810
|
+
assert P_k == P, f"kernel P={P_k} must equal kernel_sz**2 = {P}"
|
|
811
|
+
|
|
812
|
+
def _to_np_1d(ids):
|
|
813
|
+
if isinstance(ids, torch.Tensor):
|
|
814
|
+
return ids.detach().cpu().numpy().astype(np.int64, copy=False)
|
|
815
|
+
return np.asarray(ids, dtype=np.int64).reshape(-1)
|
|
816
|
+
|
|
817
|
+
def _has_multi_bind():
|
|
818
|
+
return (getattr(self, 'G', 1) > 1 and
|
|
819
|
+
getattr(self, 'pos_safe_t_multi', None) is not None and
|
|
820
|
+
getattr(self, 'w_norm_t_multi', None) is not None)
|
|
821
|
+
|
|
822
|
+
# ----------------------------
|
|
823
|
+
# Case 1: new target ids given
|
|
824
|
+
# ----------------------------
|
|
825
|
+
if cell_ids is not None:
|
|
826
|
+
cell_ids_np = _to_np_1d(cell_ids)
|
|
827
|
+
|
|
828
|
+
# A) geometry with class' G (defaults to 1 if not set)
|
|
829
|
+
G = getattr(self, 'G', 1)
|
|
830
|
+
th, ph = hp.pix2ang(self.nside, cell_ids_np, nest=self.nest)
|
|
831
|
+
self.prepare_torch(th, ph, alpha=None, G=G) # fills idx_t/_multi, w_t/_multi
|
|
832
|
+
|
|
833
|
+
# B) sort ids and reorder data accordingly
|
|
834
|
+
order = np.argsort(cell_ids_np)
|
|
835
|
+
ids_sorted_np = cell_ids_np[order]
|
|
836
|
+
assert K_data == ids_sorted_np.size, \
|
|
837
|
+
"data last dimension must equal number of provided cell_ids"
|
|
838
|
+
|
|
839
|
+
order_t = torch.as_tensor(order, device=device, dtype=torch.long)
|
|
840
|
+
data_sorted_t = data[..., order_t] # (B, Ci, K) aligned with ids_sorted_np
|
|
841
|
+
|
|
842
|
+
# C) bind sparse support
|
|
843
|
+
if G > 1:
|
|
844
|
+
self.bind_support_torch_multi(ids_sorted_np, device=device, dtype=dtype)
|
|
845
|
+
out_sorted = self.apply_multi(data_sorted_t, kernel) # (B, G*Co_g, K)
|
|
846
|
+
else:
|
|
847
|
+
self.bind_support_torch(ids_sorted_np, device=device, dtype=dtype)
|
|
848
|
+
out_sorted = self.apply(data_sorted_t, kernel) # (B, Co, K)
|
|
849
|
+
|
|
850
|
+
# D) unsort back to original order
|
|
851
|
+
inv_order = np.empty_like(order)
|
|
852
|
+
inv_order[order] = np.arange(order.size)
|
|
853
|
+
inv_idx = torch.as_tensor(inv_order, device=device, dtype=torch.long)
|
|
854
|
+
return out_sorted[..., inv_idx]
|
|
855
|
+
|
|
856
|
+
# -----------------------------------------------
|
|
857
|
+
# Case 2: fast path on cached geometry + mapping
|
|
858
|
+
# -----------------------------------------------
|
|
859
|
+
if self.ids_sorted_np is None:
|
|
860
|
+
if getattr(self, 'cell_ids_default', None) is not None:
|
|
861
|
+
self.ids_sorted_np = np.sort(self.cell_ids_default)
|
|
862
|
+
else:
|
|
863
|
+
raise AssertionError(
|
|
864
|
+
"No cached targets. Either pass `cell_ids` once or initialize the class with `cell_ids=`."
|
|
865
|
+
)
|
|
866
|
+
|
|
867
|
+
if _has_multi_bind():
|
|
868
|
+
# rebind if device/dtype changed
|
|
869
|
+
if (self.device != device) or (self.dtype != dtype):
|
|
870
|
+
self.bind_support_torch_multi(self.ids_sorted_np, device=device, dtype=dtype)
|
|
871
|
+
return self.apply_multi(data, kernel)
|
|
872
|
+
|
|
873
|
+
# single-gauge cached path
|
|
874
|
+
need_rebind = (
|
|
875
|
+
getattr(self, 'pos_safe_t', None) is None or
|
|
876
|
+
getattr(self, 'w_norm_t', None) is None or
|
|
877
|
+
self.device != device or
|
|
878
|
+
self.dtype != dtype
|
|
879
|
+
)
|
|
880
|
+
if need_rebind:
|
|
881
|
+
self.bind_support_torch(self.ids_sorted_np, device=device, dtype=dtype)
|
|
882
|
+
return self.apply(data, kernel)
|
|
883
|
+
|
|
884
|
+
def Convol_torch(self, im, ww, cell_ids=None, nside=None):
    """
    Batched convolution dispatcher built on top of ``_Convol_Torch``.

    Accepted inputs
    ---------------
    - `im` is a Tensor (B, Ci, K) and `cell_ids` is
        * None                -> cached targets are used (fast path)
        * 1-D array/tensor    -> one shared grid for the whole batch
        * 2-D array/tensor    -> one grid per sample, same length;
                                 results are concatenated along the batch
        * list/tuple (or an object-dtype array) -> one grid per sample,
                                 lengths may differ
    - `im` is a list/tuple of Tensors (Ci, K_b) with a list/tuple of
      `cell_ids` of the same length.

    Per-sample results are concatenated into one tensor when all samples
    have the same length; otherwise a list of (Co, K_b) tensors is returned.

    Parameters
    ----------
    im : torch.Tensor or list/tuple of torch.Tensor
    ww : torch.Tensor
        Kernel; it is moved to the device/dtype of `im` before use.
    cell_ids : see above
    nside : int, optional
        When given, ``self.nside`` is set to ``int(nside)`` for the duration
        of the call and restored afterwards.

    Raises
    ------
    TypeError
        If `ww` is not a tensor, or `im` / `cell_ids` have an unsupported type.
    """

    def _ids_to_numpy(ids):
        if isinstance(ids, torch.Tensor):
            return ids.detach().cpu().numpy().astype(np.int64, copy=False)
        return np.asarray(ids, dtype=np.int64)

    def _kernel_like(sample):
        # Move the kernel next to the data; reject non-tensor inputs.
        if not isinstance(sample, torch.Tensor):
            raise TypeError("Expected a torch.Tensor for device/dtype inference.")
        if not isinstance(ww, torch.Tensor):
            raise TypeError("kernel (ww) must be a torch.Tensor")
        return ww.to(device=sample.device, dtype=sample.dtype)

    def _collate(results, sizes):
        # One stacked tensor when every sample has the same length,
        # otherwise a plain list of per-sample (Co, K_b) tensors.
        if len(set(sizes)) == 1:
            return torch.cat(results, dim=0)
        return [r.squeeze(0) for r in results]

    im_is_tensor = isinstance(im, torch.Tensor)
    im_is_sequence = isinstance(im, (list, tuple))

    if im_is_tensor:
        kernel = _kernel_like(im)
    elif im_is_sequence:
        assert isinstance(cell_ids, (list, tuple)) and len(cell_ids) == len(im), \
            "When im is a list, cell_ids must be a list of same length."
        assert len(im) > 0, "Empty list for `im`."
        kernel = _kernel_like(im[0])
        device, dtype = im[0].device, im[0].dtype
    else:
        raise TypeError("`im` must be either a torch.Tensor (B,Ci,K) or a list of (Ci,K_b) tensors.")

    # Temporarily override nside; always restored on the way out.
    saved_nside = self.nside
    active_nside = int(nside) if nside is not None else saved_nside
    self.nside = active_nside
    try:
        if im_is_sequence:
            results, sizes = [], []
            for x_b, ids_b in zip(im, cell_ids):
                assert isinstance(x_b, torch.Tensor), "Each sample in `im` must be a torch.Tensor"
                assert x_b.device == device and x_b.dtype == dtype, "All samples must share device/dtype."
                y_b = self._Convol_Torch(x_b.unsqueeze(0), kernel, cell_ids=_ids_to_numpy(ids_b))
                results.append(y_b)
                sizes.append(y_b.shape[-1])
            return _collate(results, sizes)

        # ---- tensor input from here on ----
        if cell_ids is None:
            return self._Convol_Torch(im, kernel, cell_ids=None)

        # Ragged numpy arrays are handled like python lists.
        if isinstance(cell_ids, np.ndarray) and cell_ids.dtype == object:
            cell_ids = list(cell_ids)

        if isinstance(cell_ids, (np.ndarray, torch.Tensor)):
            rank = getattr(cell_ids, "ndim", 1)
            if rank == 1:
                return self._Convol_Torch(im, kernel, cell_ids=_ids_to_numpy(cell_ids))
            if rank == 2:
                n_batch = im.shape[0]
                if isinstance(cell_ids, torch.Tensor):
                    assert cell_ids.shape[0] == n_batch, "cell_ids first dim must match batch size B"
                    ids2d = _ids_to_numpy(cell_ids)
                else:
                    ids2d = _ids_to_numpy(cell_ids)
                    assert ids2d.shape[0] == n_batch, "cell_ids first dim must match batch size B"
                per_sample = [
                    self._Convol_Torch(im[b:b + 1], kernel, cell_ids=ids2d[b])
                    for b in range(n_batch)
                ]
                return torch.cat(per_sample, dim=0)

        if isinstance(cell_ids, (list, tuple)):
            n_batch = im.shape[0]
            assert len(cell_ids) == n_batch, "cell_ids list length must match batch size B"
            results, sizes = [], []
            for b in range(n_batch):
                ids_b_np = _ids_to_numpy(cell_ids[b])
                sizes.append(ids_b_np.size)
                results.append(self._Convol_Torch(im[b:b + 1], kernel, cell_ids=ids_b_np))
            return _collate(results, sizes)

        raise TypeError("Unsupported type for cell_ids with tensor input.")
    finally:
        self.nside = saved_nside
|
|
1015
|
+
|
|
1016
|
+
def make_matrix(
    self,
    kernel: torch.Tensor,
    cell_ids=None,
    *,
    return_sparse_tensor: bool = False,
    chunk_k: int = 4096,
):
    """
    Build the sparse COO operator equivalent to the stencil convolution.

    Each non-zero couples one output entry ``co*K + k_out`` with one input
    entry ``ci*K + k_in`` and carries ``kernel[ci, co, p] * w[j, k_out, p]``,
    where ``k_in`` / ``w`` come from the bound support (``pos_safe_t`` /
    ``w_norm_t`` or their ``*_multi`` variants). Duplicate couples are left
    un-merged in the tuple return and summed by ``coalesce()`` in the
    sparse-tensor return.

    Parameters
    ----------
    kernel : torch.Tensor
        (Ci, Co_g, P) shared by all gauges, or (G, Ci, Co_g, P) per gauge,
        with ``P == self.P``. Its device/dtype are used for the result.
        A 4-D kernel sets ``self.G`` when the instance has no ``G`` yet.
    cell_ids : array-like or torch.Tensor, optional
        Target pixel ids (flattened). When given, geometry and binding are
        recomputed for them (overwriting the instance cache); when None the
        cached binding is reused and only re-bound if device/dtype differ.
    return_sparse_tensor : bool, default False
        If True return a coalesced ``torch.sparse_coo_tensor`` of shape
        (Co*K, Ci*K) indexed as [row; col], so that ``M @ vec(data)``
        yields ``vec(out)``. Otherwise return ``(weights, indices, shape)``.
    chunk_k : int, default 4096
        Number of target pixels processed per chunk (peak-memory control).

    Returns
    -------
    M : torch.Tensor (sparse COO), if ``return_sparse_tensor``
    (weights, indices, shape) otherwise, where
        weights : torch.Tensor (nnz,)
        indices : torch.LongTensor (2, nnz), stacked as **[col; row]**
                  (input index first, output index second); this inverted
                  layout is what existing callers of the tuple form receive
        shape   : (Co*K, Ci*K) with Co = G*Co_g

    Raises
    ------
    ValueError
        If the kernel is neither 3-D nor 4-D, or ``chunk_k < 1``.
    AssertionError
        On kernel/instance mismatch (G or P) or when no cached targets exist.

    Notes
    -----
    For multi-gauge bindings, output channels are laid out gauge after
    gauge: gauge ``g`` owns channels ``g*Co_g .. (g+1)*Co_g - 1``.
    """
    if chunk_k < 1:
        raise ValueError("chunk_k must be a positive integer")

    device = kernel.device
    k_dtype = kernel.dtype

    # --- validate kernel & read its dimensions
    if kernel.dim() == 3:
        Ci, Co_g, P = kernel.shape
    elif kernel.dim() == 4:
        Gk, Ci, Co_g, P = kernel.shape
        if hasattr(self, 'G'):
            assert Gk == self.G, f"kernel first dim G={Gk} must match self.G={self.G}"
        else:
            self.G = int(Gk)
    else:
        raise ValueError("kernel must be (Ci,Co_g,P) or (G,Ci,Co_g,P)")

    assert P == self.P, f"kernel P={P} must equal kernel_sz**2={self.P}"

    # --- geometry + binding for these ids (or use cached)
    if cell_ids is not None:
        if isinstance(cell_ids, torch.Tensor):
            cell_ids_np = cell_ids.detach().cpu().numpy().astype(np.int64, copy=False).reshape(-1)
        else:
            cell_ids_np = np.asarray(cell_ids, dtype=np.int64).reshape(-1)

        # Step A: geometry with the class' number of gauges
        G = int(getattr(self, 'G', 1))
        th, ph = hp.pix2ang(self.nside, cell_ids_np, nest=self.nest)
        self.prepare_torch(th, ph, alpha=None, G=G)

        # Step B: bind on sorted ids, and remember K
        ids_sorted_np = cell_ids_np[np.argsort(cell_ids_np)]
        K = ids_sorted_np.size
        if G > 1:
            self.bind_support_torch_multi(ids_sorted_np, device=device, dtype=k_dtype)
        else:
            self.bind_support_torch(ids_sorted_np, device=device, dtype=k_dtype)
    else:
        if getattr(self, 'ids_sorted_np', None) is None:
            raise AssertionError("No cached targets; pass `cell_ids` or init the class with `cell_ids=`.")
        K = self.ids_sorted_np.size
        # rebind to the kernel device/dtype only if needed
        if (self.device != device) or (self.dtype != k_dtype):
            if getattr(self, 'G', 1) > 1:
                self.bind_support_torch_multi(self.ids_sorted_np, device=device, dtype=k_dtype)
            else:
                self.bind_support_torch(self.ids_sorted_np, device=device, dtype=k_dtype)

    G = int(getattr(self, 'G', 1))
    Co_total = G * Co_g  # output channels including gauges
    shape = (Co_total * K, Ci * K)

    # --- choose mapping tensors (multi vs single)
    is_multi = (G > 1) and (getattr(self, 'pos_safe_t_multi', None) is not None)
    if is_multi:
        pos_all_g = self.pos_safe_t_multi.to(device=device)                # (G, 4, K*P)
        w_all_g = self.w_norm_t_multi.to(device=device, dtype=k_dtype)
    else:
        pos_all = self.pos_safe_t.to(device=device)                        # (4, K*P)
        w_all = self.w_norm_t.to(device=device, dtype=k_dtype)

    # Column offset of each input channel, shaped to sit on the channel axis
    # of the (Co_g, Kb, Ci, 4, P) layout used below. A plain (Ci, 1) tensor
    # would broadcast onto the neighbour axis instead.
    col_base = (torch.arange(Ci, device=device, dtype=torch.long) * K).view(1, 1, Ci, 1, 1)

    rows_all, cols_all, vals_all = [], [], []

    def _accumulate_for_gauge(g, pos_g, w_g, ker_g):
        """
        Append the COO entries of gauge `g`.

        pos_g : (4, K*P) long   -- input positions of the 4 neighbours
        w_g   : (4, K*P) float  -- matching interpolation weights
        ker_g : (Ci, Co_g, P)
        """
        ker_b = ker_g.permute(1, 0, 2)[:, None, :, None, :]               # (Co_g, 1, Ci, 1, P)
        # Gauge g owns output channels g*Co_g .. (g+1)*Co_g - 1.
        out_ch = torch.arange(Co_g, device=device, dtype=torch.long) + g * Co_g

        # process by chunks over target pixels to bound memory
        for start in range(0, K, chunk_k):
            stop = min(start + chunk_k, K)
            Kb = stop - start
            span = slice(start * P, stop * P)

            pos = pos_g[:, span].reshape(4, Kb, P).permute(1, 0, 2)       # (Kb, 4, P)
            w = w_g[:, span].reshape(4, Kb, P).permute(1, 0, 2)           # (Kb, 4, P)

            k_out = start + torch.arange(Kb, device=device, dtype=torch.long)
            rows = out_ch[:, None] * K + k_out[None, :]                   # (Co_g, Kb)
            rows = rows[:, :, None, None, None].expand(Co_g, Kb, Ci, 4, P)

            cols = pos[None, :, None, :, :] + col_base                    # (1, Kb, Ci, 4, P)
            cols = cols.expand(Co_g, Kb, Ci, 4, P)

            vals = ker_b * w[None, :, None, :, :]                         # (Co_g, Kb, Ci, 4, P)

            rows_all.append(rows.reshape(-1))
            cols_all.append(cols.reshape(-1))
            vals_all.append(vals.reshape(-1))

    # --- accumulate either single- or multi-gauge
    if is_multi:
        for g in range(G):
            ker_g = kernel if kernel.dim() == 3 else kernel[g]
            _accumulate_for_gauge(g, pos_all_g[g], w_all_g[g], ker_g)
    else:
        _accumulate_for_gauge(0, pos_all, w_all, kernel if kernel.dim() == 3 else kernel[0])

    if rows_all:
        rows = torch.cat(rows_all, dim=0)
        cols = torch.cat(cols_all, dim=0)
        vals = torch.cat(vals_all, dim=0)
    else:
        # Empty grid (K == 0): nothing to couple.
        rows = torch.empty(0, device=device, dtype=torch.long)
        cols = torch.empty(0, device=device, dtype=torch.long)
        vals = torch.empty(0, device=device, dtype=k_dtype)

    if return_sparse_tensor:
        # [row; col] so the indices agree with `shape` = (rows, cols).
        rc = torch.stack([rows, cols], dim=0)
        return torch.sparse_coo_tensor(rc, vals, size=shape, device=device, dtype=k_dtype).coalesce()

    # Tuple form keeps the historical [col; row] layout expected by callers.
    indices = torch.stack([cols, rows], dim=0)
    return vals, indices, shape
|
|
1225
|
+
|
|
1226
|
+
|
|
1227
|
+
def _to_numpy_1d(self, ids):
|
|
1228
|
+
"""Return a 1D numpy array of int64 for a single set of cell ids."""
|
|
1229
|
+
import numpy as np, torch
|
|
1230
|
+
if isinstance(ids, np.ndarray):
|
|
1231
|
+
return ids.reshape(-1).astype(np.int64, copy=False)
|
|
1232
|
+
if torch.is_tensor(ids):
|
|
1233
|
+
return ids.detach().cpu().to(torch.long).view(-1).numpy()
|
|
1234
|
+
# python list/tuple of ints
|
|
1235
|
+
return np.asarray(ids, dtype=np.int64).reshape(-1)
|
|
1236
|
+
|
|
1237
|
+
def _is_varlength_batch(self, ids):
|
|
1238
|
+
"""
|
|
1239
|
+
True if ids is a list/tuple of per-sample id arrays (var-length batch).
|
|
1240
|
+
False if ids is a single array/tensor of ids (shared for whole batch).
|
|
1241
|
+
"""
|
|
1242
|
+
import numpy as np, torch
|
|
1243
|
+
if isinstance(ids, (list, tuple)):
|
|
1244
|
+
return True
|
|
1245
|
+
if isinstance(ids, np.ndarray) and ids.ndim == 2:
|
|
1246
|
+
# This would be a dense (B, Npix) matrix -> NOT var-length list
|
|
1247
|
+
return False
|
|
1248
|
+
if torch.is_tensor(ids) and ids.dim() == 2:
|
|
1249
|
+
return False
|
|
1250
|
+
return False
|
|
1251
|
+
|
|
1252
|
+
def Down(self, im, cell_ids=None, nside=None, max_poll=False):
    """
    Downgrade `im` by delegating to ``self.f.ud_grade_2``.

    ``self.f`` is created on first use via ``sc.funct`` with
    ``all_type='float64'`` when ``self.dtype`` is ``torch.float64`` and
    ``'float32'`` otherwise.

    Parameters
    ----------
    im : torch.Tensor or sequence of torch.Tensor
        Batched data, or one (C, N_b) entry per sample in var-length mode.
    cell_ids : optional
        - None: use ``self.cell_ids`` and ``self.nside`` (the `nside`
          argument is ignored on this path).
        - list/tuple: one id set per sample (var-length mode).
        - anything else: a single id set shared by the whole batch.
    nside : int, optional
        Resolution of `cell_ids`; defaults to ``self.nside``.
    max_poll : bool, default False
        Forwarded unchanged to ``ud_grade_2`` on every path.

    Returns
    -------
    (Tensor, ids) for a single id set, or (list[Tensor], list[LongTensor])
    in var-length mode, with each returned id set converted to a long
    tensor on the device of the matching output.
    """
    if self.f is None:
        all_type = 'float64' if self.dtype == torch.float64 else 'float32'
        self.f = sc.funct(KERNELSZ=self.KERNELSZ, all_type=all_type)

    if cell_ids is None:
        dim, cdim = self.f.ud_grade_2(im, cell_ids=self.cell_ids, nside=self.nside, max_poll=max_poll)
        return dim, cdim

    if nside is None:
        nside = self.nside

    # var-length mode: list/tuple of ids, one per sample
    if self._is_varlength_batch(cell_ids):
        outs, outs_ids = [], []
        im_is_tensor = torch.is_tensor(im)
        for b in range(len(cell_ids)):
            cid_b = self._to_numpy_1d(cell_ids[b])
            # pick sample b and keep a leading batch axis of size 1
            xb = im[b:b + 1] if im_is_tensor else im[b][None, ...]
            yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside, max_poll=max_poll)
            outs.append(yb.squeeze(0))  # (C, N_b')
            outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
        return outs, outs_ids

    # shared grid (a single id vector); honour the caller's max_poll here too
    cid = self._to_numpy_1d(cell_ids)
    return self.f.ud_grade_2(im, cell_ids=cid, nside=nside, max_poll=max_poll)
|
|
1292
|
+
|
|
1293
|
+
def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
    """
    Upgrade `im` to ``2 * nside`` by delegating to ``self.f.up_grade``.

    ``self.f`` is created on first use via ``sc.funct`` ('float64' when
    ``self.dtype`` is ``torch.float64``, 'float32' otherwise).

    - ``cell_ids is None``: use ``self.cell_ids`` / ``self.nside``.
    - `cell_ids` is a list/tuple: var-length mode; `o_cell_ids` must be a
      list/tuple of the same length and a list of per-sample tensors is
      returned.
    - otherwise: a single shared id set; a single result is returned.
    """
    if self.f is None:
        precision = 'float64' if self.dtype == torch.float64 else 'float32'
        self.f = sc.funct(KERNELSZ=self.KERNELSZ, all_type=precision)

    if cell_ids is None:
        return self.f.up_grade(im, self.nside*2, cell_ids=self.cell_ids, nside=self.nside)

    if nside is None:
        nside = self.nside

    if not self._is_varlength_batch(cell_ids):
        # shared grid for the whole batch
        coarse = self._to_numpy_1d(cell_ids)
        fine = None if o_cell_ids is None else self._to_numpy_1d(o_cell_ids)
        return self.f.up_grade(im, nside*2, cell_ids=coarse, nside=nside,
                               o_cell_ids=fine, force_init_index=True)

    # var-length mode: parallel lists of coarse and fine ids
    assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
        "In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
    batched = torch.is_tensor(im)
    upgraded = []
    for b in range(len(cell_ids)):
        coarse_b = self._to_numpy_1d(cell_ids[b])
        fine_b = self._to_numpy_1d(o_cell_ids[b])
        # sample b with a leading batch axis of size 1
        sample = im[b:b+1] if batched else im[b][None, ...]
        res = self.f.up_grade(sample, nside*2, cell_ids=coarse_b, nside=nside,
                              o_cell_ids=fine_b, force_init_index=True)
        upgraded.append(res.squeeze(0))
    return upgraded
|
|
1337
|
+
|
|
1338
|
+
|
|
1339
|
+
def to_tensor(self, x):
    """Build a new tensor from `x` with ``self.dtype`` on ``self.device``."""
    placement = {"device": self.device, "dtype": self.dtype}
    return torch.tensor(x, **placement)
|
|
1341
|
+
|
|
1342
|
+
def to_numpy(self, x):
    """
    Return `x` as a numpy array.

    numpy arrays are returned as-is (same object). Anything else is treated
    as a torch tensor: it is detached from the autograd graph, moved to the
    CPU and converted.
    """
    if isinstance(x, np.ndarray):
        return x
    # detach first: Tensor.numpy() raises on tensors that require grad
    return x.detach().cpu().numpy()
|
|
1346
|
+
|