foscat 2025.9.1__py3-none-any.whl → 2025.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1346 @@
1
+ # SPDX-License-Identifier: MIT
2
+ # Author: J.-M. Delouis
3
+ import numpy as np
4
+ import healpy as hp
5
+ import foscat.scat_cov as sc
6
+ import torch
7
+
8
+ import numpy as np
9
+ import torch
10
+ import healpy as hp
11
+
12
+
13
+ class SphericalStencil:
14
+ """
15
+ GPU-accelerated spherical stencil operator for HEALPix convolutions.
16
+
17
+ This class implements three phases:
18
+ A) Geometry preparation: build local rotated stencil vectors for each target
19
+ pixel, compute HEALPix neighbor indices and interpolation weights.
20
+ B) Sparse binding: map neighbor indices/weights to available data samples
21
+ (sorted ids), and normalize weights.
22
+ C) Convolution: apply multi-channel kernels to sparse gathered data.
23
+
24
+ Once A+B are prepared, multiple convolutions (C) can be applied efficiently
25
+ on the GPU.
26
+
27
+ Parameters
28
+ ----------
29
+ nside : int
30
+ HEALPix resolution parameter.
31
+ kernel_sz : int
32
+ Size of local stencil (must be odd, e.g. 3, 5, 7).
33
+ gauge_type : str
34
+         Type of gauge:
35
+         'cosmo' uses the same definition as
36
+         https://www.aanda.org/articles/aa/abs/2022/12/aa44566-22/aa44566-22.html
37
+         'phi' is defined at the pole, which can be better for Earth observation not using the poles intensively
38
+     n_gauges : int
39
+         Number of oriented gauges (default 1).
40
+ blend : bool
41
+ Whether to blend smoothly between axisA and axisB (dual gauge).
42
+ power : float
43
+ Sharpness of blend transition (dual gauge).
44
+ nest : bool
45
+ Use nested ordering if True (default), else ring ordering.
46
+ cell_ids : np.ndarray | torch.Tensor | None
47
+         If given, initialize Steps A and B immediately for these targets.
48
+ device : torch.device | str | None
49
+ Default device (if None, 'cuda' if available else 'cpu').
50
+ dtype : torch.dtype | None
51
+ Default dtype (float32 if None).
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ nside: int,
57
+ kernel_sz: int,
58
+ *,
59
+ nest: bool = True,
60
+ cell_ids=None,
61
+ device=None,
62
+ dtype=None,
63
+ n_gauges=1,
64
+ gauge_type='cosmo',
65
+ scat_op=None,
66
+ ):
67
+ assert kernel_sz >= 1 and int(kernel_sz) == kernel_sz
68
+ assert kernel_sz % 2 == 1, "kernel_sz must be odd"
69
+
70
+ self.nside = int(nside)
71
+ self.KERNELSZ = int(kernel_sz)
72
+ self.P = self.KERNELSZ * self.KERNELSZ
73
+
74
+ self.G = n_gauges
75
+ self.gauge_type=gauge_type
76
+
77
+ self.nest = bool(nest)
78
+ if scat_op is None:
79
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ)
80
+ else:
81
+ self.f=scat_op
82
+
83
+ # Torch defaults
84
+ if device is None:
85
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
86
+ if dtype is None:
87
+ dtype = torch.float32
88
+ self.device = torch.device(device)
89
+ self.dtype = dtype
90
+
91
+ # Geometry cache
92
+ self.Kb = None
93
+ self.idx_t = None # (4, K*P) neighbor indices
94
+ self.w_t = None # (4, K*P) interpolation weights
95
+ self.ids_sorted_np = None
96
+ self.pos_safe_t = None
97
+ self.w_norm_t = None
98
+ self.present_t = None
99
+
100
+         # Optional: keep a copy of the default ids if provided
101
+ self.cell_ids_default = None
102
+
103
+ # ---- Optional immediate preparation (Step A+B at init) ----
104
+ if cell_ids is not None:
105
+ # Keep a copy of the default target grid (fast-path later)
106
+ cid = np.asarray(cell_ids, dtype=np.int64).reshape(-1)
107
+ self.cell_ids_default = cid.copy()
108
+
109
+ # Step A (Torch): build geometry for this grid with G gauges
110
+ th, ph = hp.pix2ang(self.nside, cid, nest=self.nest)
111
+ self.prepare_torch(th, ph, G=self.G) # fills idx_t/_multi and w_t/_multi
112
+
113
+ # Step B (Torch): bind sparse mapping on the class device/dtype
114
+ order = np.argsort(cid)
115
+ self.ids_sorted_np = cid[order] # cache for fast-path
116
+
117
+ if self.G > 1:
118
+ # Multi-gauge binding (produces pos_safe_t_multi, w_norm_t_multi)
119
+ self.bind_support_torch_multi(
120
+ self.ids_sorted_np,
121
+ device=self.device,
122
+ dtype=self.dtype,
123
+ )
124
+ else:
125
+ # Single-gauge binding (produces pos_safe_t, w_norm_t)
126
+ self.bind_support_torch(
127
+ self.ids_sorted_np,
128
+ device=self.device,
129
+ dtype=self.dtype,
130
+ )
131
+
132
+
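A minimal usage sketch of the fast path prepared above (hedged: assumes SphericalStencil is importable from the installed foscat package and that the default scat_op, foscat.scat_cov.funct, can be constructed; all names and sizes are illustrative):

    import numpy as np
    import healpy as hp
    import torch

    nside = 16
    cell_ids = np.arange(hp.nside2npix(nside))            # full-sphere NESTED grid
    op = SphericalStencil(nside, 3, cell_ids=cell_ids)    # Steps A and B run at init

    # data aligned with the (already sorted) cell_ids: (B, Ci, K)
    data = torch.randn(1, 1, cell_ids.size, device=op.device, dtype=op.dtype)
    kernel = torch.randn(1, 1, op.P, device=op.device, dtype=op.dtype)  # (Ci, Co, P)
    out = op.Convol_torch(data, kernel)                   # (1, 1, K), cached-grid fast path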
133
+ # ------------------------------------------------------------------
134
+ # Rotation construction in Torch
135
+ # ------------------------------------------------------------------
136
+ @staticmethod
137
+ def _rotation_total_torch(th, ph, alpha=None, G: int = 1, gauge_cosmo=True,device=None, dtype=None):
138
+ """
139
+ Build a batch of rotation matrices with *G gauges* per target.
140
+
141
+ Column-vector convention: v' = R @ v.
142
+
143
+ Parameters
144
+ ----------
145
+ th : array-like (N,)
146
+ Colatitude.
147
+ ph : array-like (N,)
148
+ Longitude.
149
+ alpha : array-like (N,) or scalar or None
150
+ Base gauge rotation angle around the local normal.
151
+ If None -> 0. For each gauge g in [0..G-1], we add g*pi/G.
152
+ G : int
153
+ Number of gauges to generate per target (>=1).
154
+ device, dtype : torch device/dtype
155
+
156
+ Returns
157
+ -------
158
+ R_tot : torch.Tensor, shape (N, G, 3, 3)
159
+ For each target i and gauge g, the matrix:
160
+ R_tot[i,g] = R_gauge(alpha[i] + g*pi/G) @ Rz(ph[i]) @ Ry(th[i])
161
+ """
162
+ assert G >= 1, "G must be >= 1"
163
+
164
+ # ---- to torch 1D
165
+ th = torch.as_tensor(th, device=device, dtype=dtype).view(-1)
166
+ ph = torch.as_tensor(ph, device=device, dtype=dtype).view(-1)
167
+ if alpha is None:
168
+ alpha = torch.zeros_like(th)
169
+ else:
170
+ alpha = torch.as_tensor(alpha, device=device, dtype=dtype).view(-1)
171
+
172
+ device = th.device
173
+ dtype = th.dtype
174
+ N = th.shape[0]
175
+
176
+ # ---- base rotation R_base = Rz(ph) @ Ry(th), shape (N,3,3)
177
+ ct, st = torch.cos(th), torch.sin(th)
178
+ cp, sp = torch.cos(ph), torch.sin(ph)
179
+
180
+ R_base = torch.zeros((N, 3, 3), device=device, dtype=dtype)
181
+ # row 0
182
+ R_base[:, 0, 0] = cp * ct
183
+ R_base[:, 0, 1] = -sp
184
+ R_base[:, 0, 2] = cp * st
185
+ # row 1
186
+ R_base[:, 1, 0] = sp * ct
187
+ R_base[:, 1, 1] = cp
188
+ R_base[:, 1, 2] = sp * st
189
+ # row 2
190
+ R_base[:, 2, 0] = -st
191
+ R_base[:, 2, 1] = 0.0
192
+ R_base[:, 2, 2] = ct
193
+
194
+ # local normal n = third column of R_base, shape (N,3)
195
+ n = R_base[:, :, 2]
196
+ n = n / torch.linalg.norm(n, dim=1, keepdim=True).clamp_min(1e-12) # safe normalize
197
+
198
+ # per-target sign: +1 if th <= pi/2 else -1
199
+ sign = torch.where(th <= (np.pi/2), torch.ones_like(th), -torch.ones_like(th)) # (N,)
200
+
201
+ # base gauge shifts (always positive)
202
+ g_shifts = torch.arange(G, device=device, dtype=dtype) * (np.pi / G) # (G,)
203
+
204
+ # broadcast with sign: (N,G)
205
+ if gauge_cosmo:
206
+ alpha_g = alpha[:, None] + sign[:, None] * g_shifts[None, :]
207
+ else:
208
+ alpha_g = alpha[:, None] + g_shifts[None, :]
209
+
210
+ ca = torch.cos(alpha_g) # (N,G)
211
+ sa = torch.sin(alpha_g) # (N,G)
212
+
213
+ # ---- expand normal to (N,G,3)
214
+ n_g = n[:, None, :].expand(N, G, 3) # (N,G,3)
215
+ nx, ny, nz = n_g[..., 0], n_g[..., 1], n_g[..., 2]
216
+
217
+ # skew-symmetric K(n_g), shape (N,G,3,3)
218
+ K = torch.zeros((N, G, 3, 3), device=device, dtype=dtype)
219
+ K[..., 0, 1] = -nz; K[..., 0, 2] = ny
220
+ K[..., 1, 0] = nz; K[..., 1, 2] = -nx
221
+ K[..., 2, 0] = -ny; K[..., 2, 1] = nx
222
+
223
+ # outer(n,n) and identity
224
+ outer = n_g.unsqueeze(-1) * n_g.unsqueeze(-2) # (N,G,3,3)
225
+ I = torch.eye(3, device=device, dtype=dtype).view(1,1,3,3).expand(N, G, 3, 3)
226
+
227
+ # ---- Rodrigues per gauge: R_gauge(N,G,3,3)
228
+ R_gauge = I * ca.view(N, G, 1, 1) + K * sa.view(N, G, 1, 1) + \
229
+ outer * (1.0 - ca).view(N, G, 1, 1)
230
+
231
+ # ---- broadcast multiply with base: R_base_g(N,G,3,3)
232
+ R_base_g = R_base.unsqueeze(1).expand(N, G, 3, 3)
233
+ R_tot = torch.matmul(R_gauge, R_base_g) # (N,G,3,3)
234
+ return R_tot
235
+
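A small sanity-check sketch of the convention above (column vectors, v' = R v): the third column of R_tot[i, g] is the unit vector toward target i and does not depend on the gauge, since the gauge rotation is taken about that local normal (illustrative only):

    import numpy as np
    import healpy as hp
    import torch

    nside = 8
    pix = np.arange(12)
    th, ph = hp.pix2ang(nside, pix, nest=True)
    R = SphericalStencil._rotation_total_torch(th, ph, G=2, device='cpu', dtype=torch.float64)
    n_expected = torch.as_tensor(hp.ang2vec(th, ph))     # (N, 3) unit vectors toward the pixels
    assert torch.allclose(R[:, 0, :, 2], R[:, 1, :, 2])  # local normal is gauge-independent
    assert torch.allclose(R[:, 0, :, 2], n_expected)     # and equals the pixel direction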
236
+ # ------------------------------------------------------------------
237
+ # Torch-based get_interp_weights wrapper
238
+ # ------------------------------------------------------------------
239
+ @staticmethod
240
+ def get_interp_weights_from_vec_torch(
241
+ nside: int,
242
+ vec,
243
+ *,
244
+ nest: bool = True,
245
+ device=None,
246
+ dtype=None,
247
+ chunk_size=1_000_000,
248
+ ):
249
+ """
250
+ Torch wrapper for healpy.get_interp_weights using input vectors.
251
+
252
+ Parameters
253
+ ----------
254
+ nside : int
255
+ HEALPix resolution.
256
+ vec : torch.Tensor (...,3)
257
+ Direction vectors (not necessarily normalized).
258
+ nest : bool
259
+ Nested ordering if True (default).
260
+ device, dtype : Torch device/dtype.
261
+ chunk_size : int
262
+ Number of points per healpy call on CPU.
263
+
264
+ Returns
265
+ -------
266
+ idx_t : LongTensor (4, *leading)
267
+ w_t : Tensor (4, *leading)
268
+ """
269
+ if not isinstance(vec, torch.Tensor):
270
+ vec = torch.as_tensor(vec, device=device, dtype=dtype)
271
+ else:
272
+ device = vec.device if device is None else device
273
+ dtype = vec.dtype if dtype is None else dtype
274
+ vec = vec.to(device=device, dtype=dtype)
275
+
276
+ orig_shape = vec.shape[:-1]
277
+ M = int(np.prod(orig_shape)) if len(orig_shape) else 1
278
+ v = vec.reshape(M, 3)
279
+
280
+ eps = torch.finfo(vec.dtype).eps
281
+ r = torch.linalg.norm(v, dim=1, keepdim=True).clamp_min(eps)
282
+ v_unit = v / r
283
+ x, y, z = v_unit[:, 0], v_unit[:, 1], v_unit[:, 2]
284
+
285
+ theta = torch.acos(z.clamp(-1.0, 1.0))
286
+ phi = torch.atan2(y, x)
287
+ two_pi = torch.tensor(2*np.pi, device=device, dtype=dtype)
288
+ phi = (phi % two_pi)
289
+
290
+ theta_np = theta.detach().cpu().numpy()
291
+ phi_np = phi.detach().cpu().numpy()
292
+
293
+ idx_accum, w_accum = [], []
294
+ for start in range(0, M, chunk_size):
295
+ stop = min(start + chunk_size, M)
296
+ t_chunk, p_chunk = theta_np[start:stop], phi_np[start:stop]
297
+ idx_np, w_np = hp.get_interp_weights(nside, t_chunk, p_chunk, nest=nest)
298
+ idx_accum.append(idx_np)
299
+ w_accum.append(w_np)
300
+
301
+ idx_np_all = np.concatenate(idx_accum, axis=1) if len(idx_accum) > 1 else idx_accum[0]
302
+ w_np_all = np.concatenate(w_accum, axis=1) if len(w_accum) > 1 else w_accum[0]
303
+
304
+ idx_t = torch.as_tensor(idx_np_all, device=device, dtype=torch.long)
305
+ w_t = torch.as_tensor(w_np_all, device=device, dtype=dtype)
306
+
307
+ if len(orig_shape):
308
+ idx_t = idx_t.view(4, *orig_shape)
309
+ w_t = w_t.view(4, *orig_shape)
310
+
311
+ return idx_t, w_t
312
+
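A quick property check of this wrapper (illustrative): whatever the query directions, the four healpy weights sum to 1 per query, which the binding step below relies on before masking and renormalizing:

    import numpy as np
    import healpy as hp
    import torch

    nside = 16
    pix = np.arange(hp.nside2npix(nside))
    vec = np.stack(hp.pix2vec(nside, pix, nest=True), axis=-1)      # (Npix, 3)
    idx, w = SphericalStencil.get_interp_weights_from_vec_torch(
        nside, vec, nest=True, device='cpu', dtype=torch.float64)   # each (4, Npix)
    assert torch.allclose(w.sum(dim=0), torch.ones(pix.size, dtype=torch.float64))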
313
+ # ------------------------------------------------------------------
314
+ # Step A: geometry preparation fully in Torch
315
+ # ------------------------------------------------------------------
316
+ def prepare_torch(self, th, ph, alpha=None, G: int = 1):
317
+ """
318
+ Prepare rotated stencil and HEALPix neighbors/weights in Torch for *G gauges*.
319
+
320
+ Parameters
321
+ ----------
322
+ th, ph : array-like, shape (K,)
323
+ Target colatitudes/longitudes.
324
+ alpha : array-like (K,) or scalar or None
325
+ Base gauge angle about the local normal at each target. If None -> 0.
326
+ For each gauge g in [0..G-1], the effective angle is alpha + g*pi/G.
327
+ G : int (>=1)
328
+ Number of gauges to generate per target.
329
+
330
+ Side effects
331
+ ------------
332
+ Sets:
333
+ - self.Kb = K
334
+ - self.G = G
335
+ - self.idx_t_multi : (G, 4, K*P) LongTensor (neighbors per gauge)
336
+ - self.w_t_multi : (G, 4, K*P) Tensor (weights per gauge)
337
+ - For backward compat when G==1:
338
+ self.idx_t : (4, K*P)
339
+ self.w_t : (4, K*P)
340
+
341
+ Returns
342
+ -------
343
+ idx_t_multi : torch.LongTensor, shape (G, 4, K*P)
344
+ w_t_multi : torch.Tensor, shape (G, 4, K*P)
345
+ """
346
+ # --- sanitize inputs on CPU (angles) then use class device/dtype
347
+ th = np.asarray(th, float).reshape(-1)
348
+ ph = np.asarray(ph, float).reshape(-1)
349
+ K = th.size
350
+ self.Kb = K
351
+ self.G = int(G)
352
+ assert self.G >= 1, "G must be >= 1"
353
+
354
+ # --- build the local (P,3) stencil once on device
355
+ P = self.P
356
+ vec_np = np.zeros((P, 3), dtype=float)
357
+ grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2) / self.nside
358
+ vec_np[:, 0] = np.tile(grid, self.KERNELSZ)
359
+ vec_np[:, 1] = np.repeat(grid, self.KERNELSZ)
360
+ vec_np[:, 2] = 1.0 - np.sqrt(vec_np[:, 0]**2 + vec_np[:, 1]**2)
361
+ vec_t = torch.as_tensor(vec_np, device=self.device, dtype=self.dtype) # (P,3)
362
+
363
+ # --- rotation matrices for all targets & gauges: (K,G,3,3)
364
+ if alpha is None:
365
+ if self.gauge_type=='cosmo':
366
+ alpha=2*((th>np.pi/2)-0.5)*ph
367
+ else:
368
+ alpha=0.0*th
369
+
370
+ R_t = self._rotation_total_torch(
371
+ th, ph, alpha, G=self.G, gauge_cosmo=(self.gauge_type=='cosmo'),
372
+ device=self.device, dtype=self.dtype
373
+ ) # shape (K,G,3,3)
374
+
375
+ # --- rotate stencil for each (target, gauge): (K,G,P,3)
376
+ # einsum over local stencil (P,3) with rotation (K,G,3,3)
377
+ rotated = torch.einsum('kgij,pj->kgpi', R_t, vec_t) # (K,G,P,3)
378
+
379
+ # --- query HEALPix (neighbors+weights) in one call over (K*G*P)
380
+ rotated_flat = rotated.reshape(-1, 3) # (K*G*P, 3)
381
+ idx_t, w_t = self.get_interp_weights_from_vec_torch(
382
+ self.nside,
383
+ rotated_flat,
384
+ nest=self.nest,
385
+ device=self.device,
386
+ dtype=self.dtype,
387
+ ) # each (4, K*G*P)
388
+
389
+ # --- reshape back to split gauges:
390
+ # current: (4, K*G*P) -> (4, K, G, P) -> (G, 4, K, P) -> (G, 4, K*P)
391
+ idx_t = idx_t.view(4, K, self.G, P).permute(2, 0, 1, 3).reshape(self.G, 4, K*P)
392
+ w_t = w_t.view(4, K, self.G, P).permute(2, 0, 1, 3).reshape(self.G, 4, K*P)
393
+
394
+ # --- cache multi-gauge versions
395
+ self.idx_t_multi = idx_t # (G, 4, K*P)
396
+ self.w_t_multi = w_t # (G, 4, K*P)
397
+
398
+ # --- backward compatibility: when G==1, also fill single-gauge fields
399
+ if self.G == 1:
400
+ self.idx_t = idx_t[0] # (4, K*P)
401
+ self.w_t = w_t[0] # (4, K*P)
402
+ else:
403
+ # when multi-gauge, you can pick a default (e.g., gauge 0) if legacy code asks
404
+ # but better to adapt bind/apply to consume the multi-gauge tensors.
405
+ self.idx_t = None
406
+ self.w_t = None
407
+
408
+ return self.idx_t_multi, self.w_t_multi
409
+
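A shape sketch for this step on a small NESTED patch (hedged: assumes the default scat_op can be constructed; values are illustrative):

    import numpy as np
    import healpy as hp

    nside, ksz = 16, 3
    op = SphericalStencil(nside, ksz, n_gauges=2)        # no cell_ids: geometry deferred
    cid = np.arange(256)                                 # small NESTED patch
    th, ph = hp.pix2ang(nside, cid, nest=True)
    idx_m, w_m = op.prepare_torch(th, ph, G=2)
    assert idx_m.shape == (2, 4, cid.size * ksz * ksz)   # (G, 4, K*P)
    assert w_m.shape == idx_m.shape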
410
+ def bind_support_torch_multi(self, ids_sorted_np, *, device=None, dtype=None):
411
+ """
412
+         Multi-gauge sparse binding (Step B) WITH the 'reduced domain' logic:
413
+           - weights of neighbors outside the domain are set to 0
414
+           - each column is renormalized to 1
415
+           - if a column is empty: fall back to the target pixel (stencil center)
416
+
417
+         Produces:
418
+ self.pos_safe_t_multi : (G, 4, K*P)
419
+ self.w_norm_t_multi : (G, 4, K*P)
420
+ self.present_t_multi : (G, 4, K*P)
421
+ """
422
+ assert hasattr(self, 'idx_t_multi') and self.idx_t_multi is not None, \
423
+ "Call prepare_torch(..., G>0) before bind_support_torch_multi(...)"
424
+ assert hasattr(self, 'w_t_multi') and self.w_t_multi is not None
425
+
426
+ if device is None: device = self.device
427
+ if dtype is None: dtype = self.dtype
428
+
429
+ self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64).reshape(-1)
430
+ ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
431
+
432
+ G, _, M = self.idx_t_multi.shape
433
+ K = self.Kb
434
+ P = self.P
435
+ assert M == K*P, "idx_t_multi second axis must have K*P columns"
436
+
437
+         # index of the stencil center (in the flattened P axis)
438
+ p_ref = (self.KERNELSZ // 2) * (self.KERNELSZ + 1) # ex. 5 -> 12
439
+
440
+ pos_list, present_list, wnorm_list = [], [], []
441
+
442
+ for g in range(G):
443
+ idx = self.idx_t_multi[g].to(device=device, dtype=torch.long) # (4, M)
444
+ w = self.w_t_multi[g].to(device=device, dtype=dtype) # (4, M)
445
+
446
+             # positions in ids_sorted
447
+ pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
448
+ in_range = pos < ids_sorted.numel()
449
+ cmp_vals = torch.full_like(idx, -1)
450
+ cmp_vals[in_range] = ids_sorted[pos[in_range]]
451
+ present = (cmp_vals == idx) # (4, M) bool
452
+
453
+             # Columns with NO neighbor present at all
454
+ empty_cols = ~present.any(dim=0) # (M,)
455
+ if empty_cols.any():
456
+ p_ref = (self.KERNELSZ // 2) * (self.KERNELSZ + 1)
457
+ k_id = torch.div(torch.arange(M, device=device), P, rounding_mode='floor') # (M,)
458
+ ref_cols = k_id * P + p_ref
459
+ src = ref_cols[empty_cols]
460
+
461
+                 # copy idx/w from the 'center' column
462
+ idx[:, empty_cols] = idx[:, src]
463
+ w[:, empty_cols] = w[:, src]
464
+
465
+ # --- Recompute presence/pos safely on those columns
466
+ idx_e = idx[:, empty_cols].reshape(-1) # (4*M_empty,)
467
+ pos_e = torch.searchsorted(ids_sorted, idx_e) # (4*M_empty,)
468
+ valid_e = pos_e < ids_sorted.numel()
469
+ pos_e_clipped = pos_e.clamp_max(max(ids_sorted.numel()-1, 0)).to(torch.long)
470
+ cmp_e = ids_sorted[pos_e_clipped]
471
+ present_e = valid_e & (cmp_e == idx_e) # (4*M_empty,)
472
+
473
+ present[:, empty_cols] = present_e.view(4, -1)
474
+ pos[:, empty_cols] = pos_e_clipped.view(4, -1)
475
+
476
+             # Zero out absent weights, then renormalize each column to 1
477
+ w = w * present
478
+ colsum = w.sum(dim=0, keepdim=True)
479
+ zero_cols = (colsum == 0)
480
+ if zero_cols.any():
481
+ w[0, zero_cols[0]] = present[0, zero_cols[0]].to(w.dtype)
482
+ colsum = w.sum(dim=0, keepdim=True)
483
+ w_norm = w / colsum.clamp_min(1e-12)
484
+
485
+ pos_safe = torch.where(present, pos, torch.zeros_like(pos))
486
+
487
+ pos_list.append(pos_safe)
488
+ present_list.append(present)
489
+ wnorm_list.append(w_norm)
490
+
491
+ self.pos_safe_t_multi = torch.stack(pos_list, dim=0) # (G, 4, M)
492
+ self.present_t_multi = torch.stack(present_list, dim=0) # (G, 4, M)
493
+ self.w_norm_t_multi = torch.stack(wnorm_list, dim=0) # (G, 4, M)
494
+
495
+         # mirror runtime device/dtype
496
+ self.device = device
497
+ self.dtype = dtype
498
+
499
+ def bind_support_torch(self, ids_sorted_np, *, device=None, dtype=None):
500
+ """
501
+         Single-gauge sparse binding (Step B) WITH the 'reduced domain' logic:
502
+           - weights of neighbors outside the domain are set to 0
503
+           - each column is renormalized to 1
504
+           - if a column is empty: fall back to the target pixel (stencil center)
505
+ """
506
+ if device is None:
507
+ device = self.device
508
+ if dtype is None:
509
+ dtype = self.dtype
510
+
511
+ self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64)
512
+ ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
513
+
514
+ idx = self.idx_t.to(device=device, dtype=torch.long) # (4, K*P)
515
+ w = self.w_t.to(device=device, dtype=dtype) # (4, K*P)
516
+
517
+ K = self.Kb
518
+ P = self.P
519
+ M = K * P
520
+
521
+         # positions in ids_sorted
522
+ pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
523
+ in_range = pos < ids_sorted.shape[0]
524
+ cmp_vals = torch.full_like(idx, -1)
525
+ cmp_vals[in_range] = ids_sorted[pos[in_range]]
526
+ present = (cmp_vals == idx) # (4, M)
527
+
528
+         # Fallback for empty columns -> stencil center
529
+ p_ref = (self.KERNELSZ // 2) * (self.KERNELSZ + 1)
530
+ empty_cols = ~present.any(dim=0) # (M,)
531
+ if empty_cols.any():
532
+ k_id = torch.div(torch.arange(M, device=device), P, rounding_mode='floor') # (M,)
533
+ ref_cols = k_id * P + p_ref
534
+ src = ref_cols[empty_cols]
535
+
536
+             # copy idx/w from the 'center' column
537
+ idx[:, empty_cols] = idx[:, src]
538
+ w[:, empty_cols] = w[:, src]
539
+
540
+ # --- Recompute presence/pos safely on those columns
541
+ idx_e = idx[:, empty_cols].reshape(-1) # (4*M_empty,)
542
+ pos_e = torch.searchsorted(ids_sorted, idx_e) # (4*M_empty,)
543
+ # valid positions strictly inside [0, len)
544
+ valid_e = pos_e < ids_sorted.numel()
545
+ pos_e_clipped = pos_e.clamp_max(max(ids_sorted.numel()-1, 0)).to(torch.long)
546
+ cmp_e = ids_sorted[pos_e_clipped]
547
+ present_e = valid_e & (cmp_e == idx_e) # (4*M_empty,)
548
+
549
+ # reshape back
550
+ present[:, empty_cols] = present_e.view(4, -1)
551
+ pos[:, empty_cols] = pos_e_clipped.view(4, -1)
552
+
553
+         # Zero absent weights + renormalize to 1
554
+ w = w * present
555
+ colsum = w.sum(dim=0, keepdim=True)
556
+ zero_cols = (colsum == 0)
557
+ if zero_cols.any():
558
+             # force 1 on the first available row (row 0 here)
559
+ w[0, zero_cols[0]] = present[0, zero_cols[0]].to(w.dtype)
560
+ colsum = w.sum(dim=0, keepdim=True)
561
+ w_norm = w / colsum.clamp_min(1e-12)
562
+
563
+ self.pos_safe_t = torch.where(present, pos, torch.zeros_like(pos))
564
+ self.w_norm_t = w_norm
565
+ self.present_t = present
566
+
567
+ self.device = device
568
+ self.dtype = dtype
569
+
570
+
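After binding, every stencil column of w_norm sums to 1 over the neighbors kept inside the domain, with the center-pixel fallback covering empty columns. A quick check sketch on a partial NESTED domain (illustrative):

    import numpy as np
    import healpy as hp

    nside = 16
    cid = np.arange(hp.nside2npix(nside) // 2)       # partial NESTED domain
    op = SphericalStencil(nside, 3, cell_ids=cid)    # __init__ runs Steps A and B
    colsum = op.w_norm_t.sum(dim=0)                  # (K*P,)
    print(float((colsum - 1.0).abs().max()))         # ~0: columns renormalized to 1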
571
+ '''
572
+ def bind_support_torch_multi(self, ids_sorted_np, *, device=None, dtype=None):
573
+ """
574
+ Multi-gauge sparse binding (Step B).
575
+ Uses self.idx_t_multi / self.w_t_multi prepared by prepare_torch(..., G>1)
576
+ and builds, for each gauge g, (pos_safe, w_norm, present).
577
+
578
+ Parameters
579
+ ----------
580
+ ids_sorted_np : np.ndarray (K,)
581
+ Sorted pixel ids for available samples (matches the last axis of your data).
582
+ device, dtype : torch device/dtype for the produced mapping tensors.
583
+
584
+ Side effects
585
+ ------------
586
+ Sets:
587
+ - self.ids_sorted_np : (K,)
588
+ - self.pos_safe_t_multi : (G, 4, K*P) LongTensor
589
+ - self.w_norm_t_multi : (G, 4, K*P) Tensor
590
+ - self.present_t_multi : (G, 4, K*P) BoolTensor
591
+ - (and mirrors device/dtype in self.device/self.dtype)
592
+ """
593
+ assert hasattr(self, 'idx_t_multi') and self.idx_t_multi is not None, \
594
+ "Call prepare_torch(..., G>0) before bind_support_torch_multi(...)"
595
+ assert hasattr(self, 'w_t_multi') and self.w_t_multi is not None
596
+
597
+ if device is None: device = self.device
598
+ if dtype is None: dtype = self.dtype
599
+
600
+ self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64).reshape(-1)
601
+ ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
602
+
603
+ G, _, M = self.idx_t_multi.shape
604
+ K = self.Kb
605
+ P = self.P
606
+ assert M == K*P, "idx_t_multi second axis must have K*P columns"
607
+
608
+ pos_list, present_list, wnorm_list = [], [], []
609
+
610
+ for g in range(G):
611
+ idx = self.idx_t_multi[g].to(device=device, dtype=torch.long) # (4, M)
612
+ w = self.w_t_multi[g].to(device=device, dtype=dtype) # (4, M)
613
+
614
+ pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
615
+ in_range = pos < ids_sorted.numel()
616
+ cmp_vals = torch.full_like(idx, -1)
617
+ cmp_vals[in_range] = ids_sorted[pos[in_range]]
618
+ present = (cmp_vals == idx)
619
+
620
+ # normalize weights per column after masking
621
+ w = w * present
622
+ colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
623
+ w_norm = w / colsum
624
+
625
+ pos_safe = torch.where(present, pos, torch.zeros_like(pos))
626
+
627
+ pos_list.append(pos_safe)
628
+ present_list.append(present)
629
+ wnorm_list.append(w_norm)
630
+
631
+ self.pos_safe_t_multi = torch.stack(pos_list, dim=0) # (G, 4, M)
632
+ self.present_t_multi = torch.stack(present_list, dim=0) # (G, 4, M)
633
+ self.w_norm_t_multi = torch.stack(wnorm_list, dim=0) # (G, 4, M)
634
+
635
+ # mirror runtime placement
636
+ self.device = device
637
+ self.dtype = dtype
638
+
639
+ # ------------------------------------------------------------------
640
+ # Step B: bind support Torch
641
+ # ------------------------------------------------------------------
642
+ def bind_support_torch(self, ids_sorted_np, *, device=None, dtype=None):
643
+ """
644
+ Map HEALPix neighbor indices (from Step A) to actual data samples
645
+ sorted by pixel id. Produces pos_safe and normalized weights.
646
+
647
+ Parameters
648
+ ----------
649
+ ids_sorted_np : np.ndarray (K,)
650
+ Sorted pixel ids for available data.
651
+ device, dtype : Torch device/dtype for results.
652
+ """
653
+ if device is None:
654
+ device = self.device
655
+ if dtype is None:
656
+ dtype = self.dtype
657
+
658
+ self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64)
659
+ ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
660
+
661
+ idx = self.idx_t.to(device=device, dtype=torch.long)
662
+ w = self.w_t.to(device=device, dtype=dtype)
663
+
664
+ M = self.Kb * self.P
665
+ idx = idx.view(4, M)
666
+ w = w.view(4, M)
667
+
668
+ pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
669
+ in_range = pos < ids_sorted.shape[0]
670
+ cmp_vals = torch.full_like(idx, -1)
671
+ cmp_vals[in_range] = ids_sorted[pos[in_range]]
672
+ present = (cmp_vals == idx)
673
+
674
+ w = w * present
675
+ colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
676
+ w_norm = w / colsum
677
+
678
+ self.pos_safe_t = torch.where(present, pos, torch.zeros_like(pos))
679
+ self.w_norm_t = w_norm
680
+ self.present_t = present
681
+ self.device = device
682
+ self.dtype = dtype
683
+ '''
684
+ # ------------------------------------------------------------------
685
+ # Step C: apply convolution (already Torch in your code)
686
+ # ------------------------------------------------------------------
687
+ def apply_multi(self, data_sorted_t: torch.Tensor, kernel_t: torch.Tensor):
688
+ """
689
+ Apply multi-gauge convolution.
690
+
691
+ Inputs
692
+ ------
693
+ data_sorted_t : (B, Ci, K) torch.Tensor on self.device/self.dtype
694
+ kernel_t : either
695
+ - (Ci, Co_g, P) : shared kernel for all gauges
696
+ - (G, Ci, Co_g, P) : per-gauge kernels
697
+
698
+ Returns
699
+ -------
700
+ out : (B, G*Co_g, K) torch.Tensor
701
+ """
702
+ assert hasattr(self, 'pos_safe_t_multi') and self.pos_safe_t_multi is not None, \
703
+ "Call bind_support_torch_multi(...) before apply_multi(...)"
704
+ B, Ci, K = data_sorted_t.shape
705
+ G, _, M = self.pos_safe_t_multi.shape
706
+ assert M == K * self.P
707
+
708
+ # normalize kernel to per-gauge
709
+ if kernel_t.dim() == 3:
710
+ Ci_k, Co_g, P = kernel_t.shape
711
+ assert Ci_k == Ci and P == self.P
712
+ kernel_g = kernel_t[None, ...].expand(G, -1, -1, -1) # (G, Ci, Co_g, P)
713
+ elif kernel_t.dim() == 4:
714
+ Gk, Ci_k, Co_g, P = kernel_t.shape
715
+ assert Gk == G and Ci_k == Ci and P == self.P
716
+ kernel_g = kernel_t
717
+ else:
718
+ raise ValueError("kernel_t must be (Ci,Co_g,P) or (G,Ci,Co_g,P)")
719
+
720
+ outs = []
721
+ for g in range(G):
722
+ pos_safe = self.pos_safe_t_multi[g] # (4, K*P)
723
+ w_norm = self.w_norm_t_multi[g] # (4, K*P)
724
+
725
+ # gather four neighbors then weight -> (B,Ci,K,P)
726
+ vals_g = []
727
+ for j in range(4):
728
+ vj = data_sorted_t.index_select(2, pos_safe[j].reshape(-1)) # (B,Ci,K*P)
729
+ vj = vj.view(B, Ci, K, self.P)
730
+ vals_g.append(vj * w_norm[j].view(1, 1, K, self.P))
731
+ tmp = sum(vals_g) # (B,Ci,K,P)
732
+
733
+ # spatial+channel mixing with kernel of this gauge -> (B,Co_g,K)
734
+ yg = torch.einsum('bckp,cop->bok', tmp, kernel_g[g])
735
+ outs.append(yg)
736
+
737
+ # concat the gauges along channel dimension: (B, G*Co_g, K)
738
+ return torch.cat(outs, dim=1)
739
+
740
+ def apply(self, data_sorted_t, kernel_t):
741
+ """
742
+ Apply the (Ci,Co,P) kernel to batched sparse data (B,Ci,K)
743
+ using precomputed pos_safe and w_norm. Runs fully on GPU.
744
+
745
+ Parameters
746
+ ----------
747
+ data_sorted_t : torch.Tensor (B,Ci,K)
748
+ Input data aligned with ids_sorted.
749
+ kernel_t : torch.Tensor (Ci,Co,P)
750
+ Convolution kernel.
751
+
752
+ Returns
753
+ -------
754
+ out : torch.Tensor (B,Co,K)
755
+ """
756
+ assert self.pos_safe_t is not None and self.w_norm_t is not None
757
+ B, Ci, K = data_sorted_t.shape
758
+ Ci_k, Co, P = kernel_t.shape
759
+ assert Ci_k == Ci and P == self.P
760
+
761
+ vals = []
762
+ for j in range(4):
763
+ vj = data_sorted_t.index_select(2, self.pos_safe_t[j].reshape(-1))
764
+ vj = vj.view(B, Ci, K, P)
765
+ vals.append(vj * self.w_norm_t[j].view(1, 1, K, P))
766
+ tmp = sum(vals) # (B,Ci,K,P)
767
+
768
+ out = torch.einsum('bckp,cop->bok', tmp, kernel_t)
769
+ return out
770
+
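For intuition, a kernel that puts all its weight on the central stencil cell nearly reproduces a smooth input map, since the central stencil vector points exactly at each target pixel and only the 4-neighbor interpolation remains. A sketch on a full-sphere grid (illustrative):

    import numpy as np
    import healpy as hp
    import torch

    nside = 16
    cid = np.arange(hp.nside2npix(nside))
    op = SphericalStencil(nside, 3, cell_ids=cid)

    delta = torch.zeros(1, 1, op.P, device=op.device, dtype=op.dtype)   # (Ci, Co, P)
    delta[0, 0, (op.KERNELSZ // 2) * (op.KERNELSZ + 1)] = 1.0           # central cell only

    th, _ = hp.pix2ang(nside, cid, nest=True)
    m = torch.as_tensor(np.cos(th), device=op.device, dtype=op.dtype)[None, None, :]
    out = op.apply(m, delta)                                            # (1, 1, K)
    print(float((out - m).abs().max()))                                 # small for a smooth map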
771
+ def _Convol_Torch(self, data: torch.Tensor, kernel: torch.Tensor, cell_ids=None) -> torch.Tensor:
772
+ """
773
+ Convenience entry point with automatic single- or multi-gauge dispatch.
774
+
775
+ Behavior
776
+ --------
777
+ - If `cell_ids is None`: use cached geometry (prepare_torch) and sparse mapping
778
+ (bind_support_torch or bind_support_torch_multi) already stored in the class,
779
+ re-binding Step-B to `data`'s device/dtype when needed, then apply.
780
+ - If `cell_ids` is provided: compute geometry + sparse mapping for these cells
781
+ using the class' gauge setup (including the number of gauges G prepared by
782
+ `prepare_torch(..., G)`), reorder `data` to match the sorted ids, apply
783
+ (single or multi), and finally unsort to the original `cell_ids` order.
784
+
785
+ Parameters
786
+ ----------
787
+ data : (B, Ci, K) torch.float
788
+ Sparse map values. Last axis K must equal the number of target pixels.
789
+ kernel : torch.Tensor
790
+ - Single-gauge path: (Ci, Co, P) where P = kernel_sz**2.
791
+ - Multi-gauge path: (Ci, Co_g, P) shared kernel for all gauges, OR
792
+ (G, Ci, Co_g, P) per-gauge kernels.
793
+ The output channels will be Co (single) or G*Co_g (multi).
794
+ cell_ids : Optional[np.ndarray | torch.Tensor], shape (K,)
795
+ Target HEALPix pixels. If None, re-use the class' cached targets.
796
+
797
+ Returns
798
+ -------
799
+ out : torch.Tensor, shape (B, Co, K)
800
+ Co = Co (single gauge) or Co = G*Co_g (multi-gauge).
801
+ """
802
+ assert isinstance(data, torch.Tensor) and isinstance(kernel, torch.Tensor), \
803
+ "data and kernel must be torch.Tensors"
804
+ device = data.device
805
+ dtype = data.dtype
806
+
807
+ B, Ci, K_data = data.shape
808
+ P = self.P
809
+ P_k = kernel.shape[-1]
810
+ assert P_k == P, f"kernel P={P_k} must equal kernel_sz**2 = {P}"
811
+
812
+ def _to_np_1d(ids):
813
+ if isinstance(ids, torch.Tensor):
814
+ return ids.detach().cpu().numpy().astype(np.int64, copy=False)
815
+ return np.asarray(ids, dtype=np.int64).reshape(-1)
816
+
817
+ def _has_multi_bind():
818
+ return (getattr(self, 'G', 1) > 1 and
819
+ getattr(self, 'pos_safe_t_multi', None) is not None and
820
+ getattr(self, 'w_norm_t_multi', None) is not None)
821
+
822
+ # ----------------------------
823
+ # Case 1: new target ids given
824
+ # ----------------------------
825
+ if cell_ids is not None:
826
+ cell_ids_np = _to_np_1d(cell_ids)
827
+
828
+ # A) geometry with class' G (defaults to 1 if not set)
829
+ G = getattr(self, 'G', 1)
830
+ th, ph = hp.pix2ang(self.nside, cell_ids_np, nest=self.nest)
831
+ self.prepare_torch(th, ph, alpha=None, G=G) # fills idx_t/_multi, w_t/_multi
832
+
833
+ # B) sort ids and reorder data accordingly
834
+ order = np.argsort(cell_ids_np)
835
+ ids_sorted_np = cell_ids_np[order]
836
+ assert K_data == ids_sorted_np.size, \
837
+ "data last dimension must equal number of provided cell_ids"
838
+
839
+ order_t = torch.as_tensor(order, device=device, dtype=torch.long)
840
+ data_sorted_t = data[..., order_t] # (B, Ci, K) aligned with ids_sorted_np
841
+
842
+ # C) bind sparse support
843
+ if G > 1:
844
+ self.bind_support_torch_multi(ids_sorted_np, device=device, dtype=dtype)
845
+ out_sorted = self.apply_multi(data_sorted_t, kernel) # (B, G*Co_g, K)
846
+ else:
847
+ self.bind_support_torch(ids_sorted_np, device=device, dtype=dtype)
848
+ out_sorted = self.apply(data_sorted_t, kernel) # (B, Co, K)
849
+
850
+ # D) unsort back to original order
851
+ inv_order = np.empty_like(order)
852
+ inv_order[order] = np.arange(order.size)
853
+ inv_idx = torch.as_tensor(inv_order, device=device, dtype=torch.long)
854
+ return out_sorted[..., inv_idx]
855
+
856
+ # -----------------------------------------------
857
+ # Case 2: fast path on cached geometry + mapping
858
+ # -----------------------------------------------
859
+ if self.ids_sorted_np is None:
860
+ if getattr(self, 'cell_ids_default', None) is not None:
861
+ self.ids_sorted_np = np.sort(self.cell_ids_default)
862
+ else:
863
+ raise AssertionError(
864
+ "No cached targets. Either pass `cell_ids` once or initialize the class with `cell_ids=`."
865
+ )
866
+
867
+ if _has_multi_bind():
868
+ # rebind if device/dtype changed
869
+ if (self.device != device) or (self.dtype != dtype):
870
+ self.bind_support_torch_multi(self.ids_sorted_np, device=device, dtype=dtype)
871
+ return self.apply_multi(data, kernel)
872
+
873
+ # single-gauge cached path
874
+ need_rebind = (
875
+ getattr(self, 'pos_safe_t', None) is None or
876
+ getattr(self, 'w_norm_t', None) is None or
877
+ self.device != device or
878
+ self.dtype != dtype
879
+ )
880
+ if need_rebind:
881
+ self.bind_support_torch(self.ids_sorted_np, device=device, dtype=dtype)
882
+ return self.apply(data, kernel)
883
+
884
+ def Convol_torch(self, im, ww, cell_ids=None, nside=None):
885
+ """
886
+ Batched KERNELSZ x KERNELSZ aggregation (dispatcher).
887
+
888
+ Supports:
889
+ - im: Tensor (B, Ci, K) with
890
+ * cell_ids is None -> use cached targets (fast path)
891
+ * cell_ids is 1D (K,) -> one shared grid for whole batch
892
+ * cell_ids is 2D (B, K) -> per-sample grids, same length; returns (B, Co, K)
893
+ * cell_ids is list/tuple -> per-sample grids (var-length allowed)
894
+ - im: list/tuple of Tensors, each (Ci, K_b) with cell_ids list/tuple
895
+
896
+ Notes
897
+ -----
898
+ - Kernel shapes accepted:
899
+ * single/multi shared: (Ci, Co_g, P)
900
+ * per-gauge kernels: (G, Ci, Co_g, P)
901
+ The low-level _Convol_Torch will choose between apply/apply_multi
902
+ depending on the class state (G>1 and multi-bind present).
903
+ """
904
+ import numpy as np
905
+ import torch
906
+
907
+ def _dev_dtype_like(x: torch.Tensor):
908
+ if not isinstance(x, torch.Tensor):
909
+ raise TypeError("Expected a torch.Tensor for device/dtype inference.")
910
+ return x.device, x.dtype
911
+
912
+ def _prepare_kernel(k: torch.Tensor, device, dtype):
913
+ if not isinstance(k, torch.Tensor):
914
+ raise TypeError("kernel (ww) must be a torch.Tensor")
915
+ return k.to(device=device, dtype=dtype)
916
+
917
+ def _to_np_ids(ids):
918
+ if isinstance(ids, torch.Tensor):
919
+ return ids.detach().cpu().numpy().astype(np.int64, copy=False)
920
+ return np.asarray(ids, dtype=np.int64)
921
+
922
+ class _NsideContext:
923
+ def __init__(self, obj, nside_new):
924
+ self.obj = obj
925
+ self.nside_old = obj.nside
926
+ self.nside_new = int(nside_new) if nside_new is not None else obj.nside
927
+ def __enter__(self):
928
+ self.obj.nside = self.nside_new
929
+ return self
930
+ def __exit__(self, exc_type, exc, tb):
931
+ self.obj.nside = self.nside_old
932
+
933
+ # ---------------- main dispatcher ----------------
934
+ if isinstance(im, torch.Tensor):
935
+ device, dtype = _dev_dtype_like(im)
936
+ kernel = _prepare_kernel(ww, device, dtype)
937
+
938
+ with _NsideContext(self, nside):
939
+ # (A) Fast path: no ids provided -> delegate fully to _Convol_Torch
940
+ if cell_ids is None:
941
+ return self._Convol_Torch(im, kernel, cell_ids=None)
942
+
943
+ # Normalise numpy/tensor ragged inputs
944
+ if isinstance(cell_ids, np.ndarray) and cell_ids.dtype == object:
945
+ cell_ids = list(cell_ids)
946
+
947
+ # (B) One shared grid for entire batch: 1-D ids
948
+ if isinstance(cell_ids, (np.ndarray, torch.Tensor)) and getattr(cell_ids, "ndim", 1) == 1:
949
+ return self._Convol_Torch(im, kernel, cell_ids=_to_np_ids(cell_ids))
950
+
951
+ # (C) Per-sample grids, same length: 2-D ids (B, K)
952
+ if isinstance(cell_ids, (np.ndarray, torch.Tensor)) and getattr(cell_ids, "ndim", 0) == 2:
953
+ B = im.shape[0]
954
+ if isinstance(cell_ids, torch.Tensor):
955
+ assert cell_ids.shape[0] == B, "cell_ids first dim must match batch size B"
956
+ ids2d = cell_ids.detach().cpu().numpy().astype(np.int64, copy=False)
957
+ else:
958
+ ids2d = np.asarray(cell_ids, dtype=np.int64)
959
+ assert ids2d.shape[0] == B, "cell_ids first dim must match batch size B"
960
+
961
+ outs = []
962
+ for b in range(B):
963
+ x_b = im[b:b+1] # (1, Ci, K_b)
964
+ ids_b = ids2d[b] # (K_b,)
965
+ y_b = self._Convol_Torch(x_b, kernel, cell_ids=ids_b) # (1, Co, K_b)
966
+ outs.append(y_b)
967
+ return torch.cat(outs, dim=0) # (B, Co, K)
968
+
969
+ # (D) Per-sample grids, variable length: list/tuple
970
+ if isinstance(cell_ids, (list, tuple)):
971
+ B = im.shape[0]
972
+ assert len(cell_ids) == B, "cell_ids list length must match batch size B"
973
+ outs = []
974
+ lengths = []
975
+ for b in range(B):
976
+ ids_b_np = _to_np_ids(cell_ids[b])
977
+ lengths.append(ids_b_np.size)
978
+ x_b = im[b:b+1] # (1, Ci, K_b)
979
+ y_b = self._Convol_Torch(x_b, kernel, cell_ids=ids_b_np) # (1, Co, K_b)
980
+ outs.append(y_b)
981
+ if len(set(lengths)) == 1:
982
+ return torch.cat(outs, dim=0) # (B, Co, K)
983
+ else:
984
+ return [y.squeeze(0) for y in outs] # list[(Co, K_b)]
985
+
986
+ raise TypeError("Unsupported type for cell_ids with tensor input.")
987
+
988
+ # Case: im is list/tuple of (Ci, K_b) tensors (var-length samples)
989
+ if isinstance(im, (list, tuple)):
990
+ assert isinstance(cell_ids, (list, tuple)) and len(cell_ids) == len(im), \
991
+ "When im is a list, cell_ids must be a list of same length."
992
+ assert len(im) > 0, "Empty list for `im`."
993
+
994
+ device, dtype = _dev_dtype_like(im[0])
995
+ kernel = _prepare_kernel(ww, device, dtype)
996
+
997
+ outs = []
998
+ with _NsideContext(self, nside):
999
+ lengths = []
1000
+ tmp = []
1001
+ for x_b, ids_b in zip(im, cell_ids):
1002
+ assert isinstance(x_b, torch.Tensor), "Each sample in `im` must be a torch.Tensor"
1003
+ assert x_b.device == device and x_b.dtype == dtype, "All samples must share device/dtype."
1004
+ x_b = x_b.unsqueeze(0) # (1, Ci, K_b)
1005
+ ids_b = _to_np_ids(ids_b)
1006
+ y_b = self._Convol_Torch(x_b, kernel, cell_ids=ids_b) # (1, Co, K_b)
1007
+ tmp.append(y_b)
1008
+ lengths.append(y_b.shape[-1])
1009
+ if len(set(lengths)) == 1:
1010
+ return torch.cat(tmp, dim=0) # (B, Co, K)
1011
+ else:
1012
+ return [y.squeeze(0) for y in tmp]
1013
+
1014
+ raise TypeError("`im` must be either a torch.Tensor (B,Ci,K) or a list of (Ci,K_b) tensors.")
1015
+
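A dispatch sketch for case (B) above, one shared 1-D grid for the whole batch (hedged: assumes the default scat_op can be constructed; names and sizes are illustrative):

    import numpy as np
    import torch

    nside = 16
    op = SphericalStencil(nside, 3)                  # geometry prepared lazily
    cid = np.arange(1024)                            # shared NESTED patch
    im = torch.randn(4, 2, cid.size)                 # (B, Ci, K)
    ww = torch.randn(2, 5, op.P)                     # (Ci, Co, P)
    out = op.Convol_torch(im, ww, cell_ids=cid)      # (4, 5, K), in the order of `cid`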
1016
+ def make_matrix(
1017
+ self,
1018
+ kernel: torch.Tensor,
1019
+ cell_ids=None,
1020
+ *,
1021
+ return_sparse_tensor: bool = False,
1022
+ chunk_k: int = 4096,
1023
+ ):
1024
+ """
1025
+ Build the sparse COO matrix M such that applying M to vec(data) reproduces
1026
+ the spherical convolution performed by Convol_torch/_Convol_Torch.
1027
+
1028
+ Supports single- and multi-gauge:
1029
+ - kernel shape (Ci, Co_g, P) -> shared across G gauges, output Co = G*Co_g
1030
+ - kernel shape (G, Ci, Co_g, P) -> per-gauge kernels, same output Co = G*Co_g
1031
+
1032
+ Parameters
1033
+ ----------
1034
+ kernel : torch.Tensor
1035
+ (Ci, Co_g, P) or (G, Ci, Co_g, P) with P = kernel_sz**2.
1036
+ Must be on the device/dtype where you want the resulting matrix.
1037
+ cell_ids : array-like of shape (K,) or torch.Tensor, optional
1038
+ Target pixel IDs (NESTED if self.nest=True).
1039
+ If None, uses the grid already cached in the class (fast path).
1040
+ If provided, we prepare geometry & sparse binding for these ids.
1041
+ return_sparse_tensor : bool, default False
1042
+ If True, return a coalesced torch.sparse_coo_tensor of shape (Co*K, Ci*K).
1043
+ Else, return (weights, indices, shape) where:
1044
+ - indices is a LongTensor of shape (2, nnz) with [row; col]
1045
+ - weights is a Tensor of shape (nnz,)
1046
+ - shape is the (rows, cols) tuple
1047
+ chunk_k : int, default 4096
1048
+ Chunk size over target pixels to limit peak memory.
1049
+
1050
+ Returns
1051
+ -------
1052
+ If return_sparse_tensor:
1053
+ M : torch.sparse_coo_tensor of shape (Co*K, Ci*K), coalesced
1054
+ else:
1055
+ weights : torch.Tensor (nnz,)
1056
+ indices : torch.LongTensor (2, nnz) with [row; col]
1057
+ shape : tuple[int, int] (Co*K, Ci*K)
1058
+
1059
+ Notes
1060
+ -----
1061
+ - The resulting matrix implements the same interpolation-and-mixing as the
1062
+ GPU path (gather 4 neighbors -> normalize -> apply spatial+channel kernel),
1063
+ and matches the output of Convol_torch for the same (kernel, cell_ids).
1064
+ - For multi-gauge, rows are grouped as concatenated gauges: first all
1065
+ Co_g channels for gauge 0 over all K, then gauge 1, etc.
1066
+ """
1067
+ import numpy as np
1068
+ import torch
1069
+ import healpy as hp
1070
+
1071
+ device = kernel.device
1072
+ k_dtype = kernel.dtype
1073
+
1074
+ # --- validate kernel & normalize shapes
1075
+ if kernel.dim() == 3:
1076
+ # shared across gauges
1077
+ Ci, Co_g, P = kernel.shape
1078
+ per_gauge = False
1079
+ elif kernel.dim() == 4:
1080
+ Gk, Ci, Co_g, P = kernel.shape
1081
+ per_gauge = True
1082
+ if hasattr(self, 'G'):
1083
+ assert Gk == self.G, f"kernel first dim G={Gk} must match self.G={self.G}"
1084
+ else:
1085
+ self.G = int(Gk)
1086
+ else:
1087
+ raise ValueError("kernel must be (Ci,Co_g,P) or (G,Ci,Co_g,P)")
1088
+
1089
+ assert P == self.P, f"kernel P={P} must equal kernel_sz**2={self.P}"
1090
+
1091
+ # --- geometry + binding for these ids (or use cached)
1092
+ def _to_np_ids(ids):
1093
+ if ids is None:
1094
+ return None
1095
+ if isinstance(ids, torch.Tensor):
1096
+ return ids.detach().cpu().numpy().astype(np.int64, copy=False).reshape(-1)
1097
+ return np.asarray(ids, dtype=np.int64).reshape(-1)
1098
+
1099
+ cell_ids_np = _to_np_ids(cell_ids)
1100
+
1101
+ if cell_ids_np is not None:
1102
+ # Step A: geometry (Torch) with the class' number of gauges
1103
+ G = int(getattr(self, 'G', 1))
1104
+ th, ph = hp.pix2ang(self.nside, cell_ids_np, nest=self.nest)
1105
+ self.prepare_torch(th, ph, alpha=None, G=G)
1106
+
1107
+ # Step B: bind on sorted ids, and remember K
1108
+ order = np.argsort(cell_ids_np)
1109
+ ids_sorted_np = cell_ids_np[order]
1110
+ K = ids_sorted_np.size
1111
+
1112
+ if G > 1:
1113
+ self.bind_support_torch_multi(ids_sorted_np, device=device, dtype=k_dtype)
1114
+ else:
1115
+ self.bind_support_torch(ids_sorted_np, device=device, dtype=k_dtype)
1116
+ else:
1117
+ # use cached mapping
1118
+ if getattr(self, 'ids_sorted_np', None) is None:
1119
+ raise AssertionError("No cached targets; pass `cell_ids` or init the class with `cell_ids=`.")
1120
+ K = self.ids_sorted_np.size
1121
+ # rebind to the kernel device/dtype if needed
1122
+ if getattr(self, 'G', 1) > 1:
1123
+ if (self.device != device) or (self.dtype != k_dtype):
1124
+ self.bind_support_torch_multi(self.ids_sorted_np, device=device, dtype=k_dtype)
1125
+ else:
1126
+ if (self.device != device) or (self.dtype != k_dtype):
1127
+ self.bind_support_torch(self.ids_sorted_np, device=device, dtype=k_dtype)
1128
+
1129
+ G = int(getattr(self, 'G', 1))
1130
+ Co_total = (G * Co_g) # output channels including gauges
1131
+ shape = (Co_total * K, Ci * K)
1132
+
1133
+ # --- choose mapping tensors (multi vs single)
1134
+ is_multi = (G > 1) and (getattr(self, 'pos_safe_t_multi', None) is not None)
1135
+ if is_multi:
1136
+ pos_all_g = self.pos_safe_t_multi.to(device=device) # (G,4,K*P)
1137
+ w_all_g = self.w_norm_t_multi.to(device=device, dtype=k_dtype)
1138
+ else:
1139
+ pos_all = self.pos_safe_t.to(device=device) # (4,K*P)
1140
+ w_all = self.w_norm_t.to(device=device, dtype=k_dtype)
1141
+
1142
+ # --- precompute channel row/col bases
1143
+ # rows: for (co_total, k_out) -> co_total*K + k_out
1144
+ # cols: for (ci, k_in) -> ci*K + k_in
1145
+ row_base = (torch.arange(Co_total, device=device, dtype=torch.long) * K)[:, None] # (Co_total, 1)
1146
+ col_base = (torch.arange(Ci, device=device, dtype=torch.long) * K)[:, None] # (Ci, 1)
1147
+
1148
+
1149
+ rows_all, cols_all, vals_all = [], [], []
1150
+
1151
+ # --- helper to add one gauge block (gauge g -> Co_g*K rows)
1152
+ def _accumulate_for_gauge(g, pos_g, w_g, ker_g):
1153
+ """
1154
+ pos_g : (4, K*P) long
1155
+ w_g : (4, K*P) float
1156
+ ker_g : (Ci, Co_g, P)
1157
+ """
1158
+ # process by chunks in k to control memory
1159
+ for start in range(0, K, chunk_k):
1160
+ stop = min(start + chunk_k, K)
1161
+ Kb = stop - start
1162
+ cols_span = torch.arange(start * self.P, stop * self.P, device=device, dtype=torch.long)
1163
+
1164
+ pos = pos_g[:, cols_span].view(4, Kb, self.P) # (4, Kb, P)
1165
+ w = w_g[:, cols_span].view(4, Kb, self.P) # (4, Kb, P)
1166
+
1167
+                 # rows_gauge: row indices for this gauge g
1168
+                 # Each gauge occupies a block of Co_g output channels for EVERY pixel (K),
1169
+                 # hence the offset g*Co_g
1170
+ rows_gauge = (torch.arange(Co_g, device=device, dtype=torch.long) + g*Co_g)[:, None] * K \
1171
+ + (start + torch.arange(Kb, device=device, dtype=torch.long))[None, :]
1172
+ # -> shape (Co_g, Kb)
1173
+ rows = rows_gauge[:, :, None, None, None] # (Co_g, Kb,1,1,1)
1174
+ rows = rows.expand(Co_g, Kb, Ci, 4, self.P) # (Co_g, Kb, Ci, 4, P)
1175
+
1176
+                 # cols: column indices = (ci*K + pix)
1177
+ cols_pix = pos.permute(1, 0, 2) # (Kb, 4, P)
1178
+ cols_pix = cols_pix[None, :, None, :, :] # (1, Kb, 1, 4, P)
1179
+                 cols = col_base.view(Ci, 1, 1, 1, 1) + cols_pix  # (Ci, Kb, 1, 4, P)
1180
+ cols = cols.permute(2, 1, 0, 3, 4) # (1, Kb, Ci, 4, P)
1181
+ cols = cols.expand(Co_g, Kb, Ci, 4, self.P)
1182
+
1183
+ # values = kernel(ci, co_g, p) * w(4,kb,p)
1184
+ k_exp = ker_g.permute(1, 0, 2) # (Co_g, Ci, P)
1185
+ k_exp = k_exp[:, None, :, None, :] # (Co_g, 1, Ci, 1, P)
1186
+
1187
+                 # Fix: put the axes of w back to (Kb, 4, P) before broadcasting
1188
+ w_exp = w.permute(1, 0, 2)[None, :, None, :, :] # (1, Kb, 1, 4, P)
1189
+ w_exp = w_exp.expand(Co_g, Kb, Ci, 4, self.P) # (Co_g, Kb, Ci, 4, P)
1190
+
1191
+ vals = k_exp * w_exp # (Co_g, Kb, Ci, 4, P)
1192
+
1193
+ rows_all.append(rows.reshape(-1))
1194
+ cols_all.append(cols.reshape(-1))
1195
+ vals_all.append(vals.reshape(-1))
1196
+
1197
+
1198
+ # --- accumulate either single- or multi-gauge
1199
+ if is_multi:
1200
+ # (a) shared kernel (Ci, Co_g, P) -> repeat over gauges
1201
+ if not per_gauge and kernel.dim() == 3:
1202
+ for g in range(G):
1203
+ _accumulate_for_gauge(g, pos_all_g[g], w_all_g[g], kernel.to(device=device, dtype=k_dtype))
1204
+ # (b) per-gauge kernel (G, Ci, Co_g, P)
1205
+ else:
1206
+ for g in range(G):
1207
+ _accumulate_for_gauge(g, pos_all_g[g], w_all_g[g], kernel[g].to(device=device, dtype=k_dtype))
1208
+ else:
1209
+ # G == 1 (single-gauge path)
1210
+ g = 0
1211
+ _accumulate_for_gauge(g, pos_all, w_all, kernel if kernel.dim() == 3 else kernel[0])
1212
+
1213
+ rows = torch.cat(rows_all, dim=0)
1214
+ cols = torch.cat(cols_all, dim=0)
1215
+ vals = torch.cat(vals_all, dim=0)
1216
+
1217
+
1218
+         indices = torch.stack([cols, rows], dim=0)  # (2, nnz); rows/cols swapped to match the foscat convention
1219
+
1220
+ if return_sparse_tensor:
1221
+ M = torch.sparse_coo_tensor(indices, vals, size=shape, device=device, dtype=k_dtype).coalesce()
1222
+ return M
1223
+ else:
1224
+ return vals, indices, shape
1225
+
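A sketch of consuming the returned triplets. Because `indices` stacks [col; row] (see the comment above), they are flipped back before assembling a standard (rows, cols) sparse matrix; single-channel shapes are illustrative:

    import numpy as np
    import healpy as hp
    import torch

    nside = 16
    cid = np.arange(hp.nside2npix(nside))
    op = SphericalStencil(nside, 3, cell_ids=cid)

    ww = torch.randn(1, 1, op.P, device=op.device, dtype=op.dtype)      # (Ci, Co, P)
    vals, idx, shape = op.make_matrix(ww, cell_ids=cid)
    M = torch.sparse_coo_tensor(torch.stack([idx[1], idx[0]]), vals, size=shape).coalesce()

    data = torch.randn(1, 1, cid.size, device=op.device, dtype=op.dtype)
    y_conv = op.Convol_torch(data, ww)                                  # (1, 1, K)
    y_mat = torch.sparse.mm(M, data[0].reshape(-1, 1)).reshape(1, 1, -1)
    print(float((y_conv - y_mat).abs().max()))                          # ~0 up to float error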
1226
+
1227
+ def _to_numpy_1d(self, ids):
1228
+ """Return a 1D numpy array of int64 for a single set of cell ids."""
1229
+ import numpy as np, torch
1230
+ if isinstance(ids, np.ndarray):
1231
+ return ids.reshape(-1).astype(np.int64, copy=False)
1232
+ if torch.is_tensor(ids):
1233
+ return ids.detach().cpu().to(torch.long).view(-1).numpy()
1234
+ # python list/tuple of ints
1235
+ return np.asarray(ids, dtype=np.int64).reshape(-1)
1236
+
1237
+ def _is_varlength_batch(self, ids):
1238
+ """
1239
+ True if ids is a list/tuple of per-sample id arrays (var-length batch).
1240
+ False if ids is a single array/tensor of ids (shared for whole batch).
1241
+ """
1242
+ import numpy as np, torch
1243
+ if isinstance(ids, (list, tuple)):
1244
+ return True
1245
+ if isinstance(ids, np.ndarray) and ids.ndim == 2:
1246
+ # This would be a dense (B, Npix) matrix -> NOT var-length list
1247
+ return False
1248
+ if torch.is_tensor(ids) and ids.dim() == 2:
1249
+ return False
1250
+ return False
1251
+
1252
+ def Down(self, im, cell_ids=None, nside=None,max_poll=False):
1253
+ """
1254
+ If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
1255
+ If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
1256
+ """
1257
+ if self.f is None:
1258
+ if self.dtype==torch.float64:
1259
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
1260
+ else:
1261
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
1262
+
1263
+ if cell_ids is None:
1264
+             dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids_default,nside=self.nside,max_poll=max_poll)
1265
+ return dim,cdim
1266
+
1267
+ if nside is None:
1268
+ nside = self.nside
1269
+
1270
+ # var-length mode: list/tuple of ids, one per sample
1271
+ if self._is_varlength_batch(cell_ids):
1272
+ outs, outs_ids = [], []
1273
+ B = len(cell_ids)
1274
+ for b in range(B):
1275
+ cid_b = self._to_numpy_1d(cell_ids[b])
1276
+                 # extract the matching sample from `im`
1277
+ if torch.is_tensor(im):
1278
+ xb = im[b:b+1] # (1, C, N_b)
1279
+ yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
1280
+ outs.append(yb.squeeze(0)) # (C, N_b')
1281
+ else:
1282
+                     # if im is already a list of (C, N_b)
1283
+ xb = im[b]
1284
+ yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
1285
+ outs.append(yb.squeeze(0))
1286
+ outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
1287
+ return outs, outs_ids
1288
+
1289
+         # common grid (a single vector of ids)
1290
+         cid = self._to_numpy_1d(cell_ids)
1291
+         return self.f.ud_grade_2(im, cell_ids=cid, nside=nside, max_poll=max_poll)
1292
+
1293
+ def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
1294
+ """
1295
+ If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
1296
+ If they are lists (var-length per sample) -> return list[Tensor].
1297
+ """
1298
+ if self.f is None:
1299
+ if self.dtype==torch.float64:
1300
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
1301
+ else:
1302
+ self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
1303
+
1304
+ if cell_ids is None:
1305
+             dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids_default,nside=self.nside)
1306
+ return dim
1307
+
1308
+ if nside is None:
1309
+ nside = self.nside
1310
+
1311
+         # var-length mode: parallel lists
1312
+ if self._is_varlength_batch(cell_ids):
1313
+ assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
1314
+ "In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
1315
+ outs = []
1316
+ B = len(cell_ids)
1317
+ for b in range(B):
1318
+ cid_b = self._to_numpy_1d(cell_ids[b]) # coarse ids
1319
+ ocid_b = self._to_numpy_1d(o_cell_ids[b]) # fine ids
1320
+ if torch.is_tensor(im):
1321
+ xb = im[b:b+1] # (1, C, N_b_coarse)
1322
+ yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
1323
+ o_cell_ids=ocid_b, force_init_index=True)
1324
+ outs.append(yb.squeeze(0)) # (C, N_b_fine)
1325
+ else:
1326
+ xb = im[b] # (C, N_b_coarse)
1327
+ yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
1328
+ o_cell_ids=ocid_b, force_init_index=True)
1329
+ outs.append(yb.squeeze(0))
1330
+ return outs
1331
+
1332
+         # common grid
1333
+ cid = self._to_numpy_1d(cell_ids)
1334
+ ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
1335
+ return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
1336
+ o_cell_ids=ocid, force_init_index=True)
1337
+
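A hedged down/up sketch (assuming, as the calls above suggest, that `ud_grade_2` returns the degraded map together with its cell ids; the exact return types live in foscat.scat_cov):

    import numpy as np
    import healpy as hp
    import torch

    nside = 32
    cid = np.arange(hp.nside2npix(nside))
    op = SphericalStencil(nside, 3, cell_ids=cid)

    m = torch.randn(1, 1, cid.size, device=op.device, dtype=op.dtype)
    m_lo, cid_lo = op.Down(m, cell_ids=cid)                                # nside -> nside // 2
    m_hi = op.Up(m_lo, cell_ids=cid_lo, nside=nside // 2, o_cell_ids=cid)  # back onto `cid`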
1338
+
1339
+ def to_tensor(self,x):
1340
+ return torch.tensor(x,device=self.device,dtype=self.dtype)
1341
+
1342
+ def to_numpy(self,x):
1343
+ if isinstance(x,np.ndarray):
1344
+ return x
1345
+ return x.cpu().numpy()
1346
+