foscat 2025.10.2-py3-none-any.whl → 2026.1.1-py3-none-any.whl
- foscat/BkTorch.py +635 -141
- foscat/FoCUS.py +135 -52
- foscat/SphereDownGeo.py +380 -0
- foscat/SphereUpGeo.py +175 -0
- foscat/SphericalStencil.py +27 -246
- foscat/alm_loc.py +270 -0
- foscat/scat.py +1 -1
- foscat/scat1D.py +1 -1
- foscat/scat_cov.py +24 -24
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/METADATA +1 -69
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/RECORD +14 -11
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/WHEEL +1 -1
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/top_level.txt +0 -0
foscat/SphericalStencil.py
CHANGED
@@ -2,13 +2,8 @@
 # Author: J.-M. Delouis
 import numpy as np
 import healpy as hp
-import foscat.scat_cov as sc
 import torch
 
-import numpy as np
-import torch
-import healpy as hp
-
 
 class SphericalStencil:
     """
@@ -61,8 +56,7 @@ class SphericalStencil:
         device=None,
         dtype=None,
         n_gauges=1,
-        gauge_type='
-        scat_op=None,
+        gauge_type='phi',
     ):
         assert kernel_sz >= 1 and int(kernel_sz) == kernel_sz
         assert kernel_sz % 2 == 1, "kernel_sz must be odd"
@@ -75,10 +69,6 @@
         self.gauge_type=gauge_type
 
         self.nest = bool(nest)
-        if scat_op is None:
-            self.f=sc.funct(KERNELSZ=self.KERNELSZ)
-        else:
-            self.f=scat_op
 
         # Torch defaults
         if device is None:
@@ -354,10 +344,27 @@
         # --- build the local (P,3) stencil once on device
         P = self.P
         vec_np = np.zeros((P, 3), dtype=float)
-        grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2)
-
-
-
+        grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2)
+
+        # NEW: angular offsets
+        xx,yy=np.meshgrid(grid,grid)
+        s=1.0 # could be modified
+        alpha_pix = hp.nside2resol(self.nside, arcmin=False)  # ~ typical angular pixel size
+        dtheta = (np.sqrt(xx**2+yy**2) * alpha_pix * s).ravel()
+        dphi = (np.arctan2(yy,xx)).ravel()
+        # local spherical displacement
+        # convert to unit vectors
+        x = np.sin(dtheta) * np.cos(dphi)
+        y = np.sin(dtheta) * np.sin(dphi)
+        z = np.cos(dtheta)
+        #print(self.nside*x.reshape(self.KERNELSZ,self.KERNELSZ))
+        #print(self.nside*y.reshape(self.KERNELSZ,self.KERNELSZ))
+        #print(self.nside*z.reshape(self.KERNELSZ,self.KERNELSZ))
+        vec_np = np.stack([x, y, z], axis=-1)
+
+        #vec_np[:, 0] = np.tile(grid, self.KERNELSZ)
+        #vec_np[:, 1] = np.repeat(grid, self.KERNELSZ)
+        #vec_np[:, 2] = 1.0 - np.sqrt(vec_np[:, 0]**2 + vec_np[:, 1]**2)
         vec_t = torch.as_tensor(vec_np, device=self.device, dtype=self.dtype)  # (P,3)
 
         # --- rotation matrices for all targets & gauges: (K,G,3,3)
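The stencil construction in the hunk above replaces the old flat tile/repeat offsets (left behind as the commented-out vec_np[:, i] lines) with true spherical offsets: each kernel tap is placed at colatitude dtheta = r * alpha_pix and azimuth dphi = atan2(y, x) around the north pole, then later rotated onto each target pixel. A minimal standalone sketch of that geometry, assuming only numpy and healpy (the nside and kernel_sz values are illustrative, not package defaults):

import numpy as np
import healpy as hp

def local_stencil(nside, kernel_sz, s=1.0):
    """(kernel_sz**2, 3) unit vectors clustered around the north pole."""
    grid = np.arange(kernel_sz) - kernel_sz // 2
    xx, yy = np.meshgrid(grid, grid)
    alpha_pix = hp.nside2resol(nside, arcmin=False)          # pixel size [rad]
    dtheta = np.sqrt(xx**2 + yy**2).ravel() * alpha_pix * s  # radial offset
    dphi = np.arctan2(yy, xx).ravel()                        # azimuthal offset
    return np.stack([np.sin(dtheta) * np.cos(dphi),
                     np.sin(dtheta) * np.sin(dphi),
                     np.cos(dtheta)], axis=-1)

vecs = local_stencil(nside=16, kernel_sz=3)
assert np.allclose(np.linalg.norm(vecs, axis=1), 1.0)        # all on the sphere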
@@ -371,7 +378,7 @@
             th, ph, alpha, G=self.G, gauge_cosmo=(self.gauge_type=='cosmo'),
             device=self.device, dtype=self.dtype
         )  # shape (K,G,3,3)
-
+
         # --- rotate stencil for each (target, gauge): (K,G,P,3)
         # einsum over local stencil (P,3) with rotation (K,G,3,3)
         rotated = torch.einsum('kgij,pj->kgpi', R_t, vec_t)  # (K,G,P,3)
@@ -568,119 +575,6 @@
         self.dtype = dtype
 
 
-    '''
-    def bind_support_torch_multi(self, ids_sorted_np, *, device=None, dtype=None):
-        """
-        Multi-gauge sparse binding (Step B).
-        Uses self.idx_t_multi / self.w_t_multi prepared by prepare_torch(..., G>1)
-        and builds, for each gauge g, (pos_safe, w_norm, present).
-
-        Parameters
-        ----------
-        ids_sorted_np : np.ndarray (K,)
-            Sorted pixel ids for available samples (matches the last axis of your data).
-        device, dtype : torch device/dtype for the produced mapping tensors.
-
-        Side effects
-        ------------
-        Sets:
-          - self.ids_sorted_np : (K,)
-          - self.pos_safe_t_multi : (G, 4, K*P) LongTensor
-          - self.w_norm_t_multi : (G, 4, K*P) Tensor
-          - self.present_t_multi : (G, 4, K*P) BoolTensor
-          - (and mirrors device/dtype in self.device/self.dtype)
-        """
-        assert hasattr(self, 'idx_t_multi') and self.idx_t_multi is not None, \
-            "Call prepare_torch(..., G>0) before bind_support_torch_multi(...)"
-        assert hasattr(self, 'w_t_multi') and self.w_t_multi is not None
-
-        if device is None: device = self.device
-        if dtype is None: dtype = self.dtype
-
-        self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64).reshape(-1)
-        ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
-
-        G, _, M = self.idx_t_multi.shape
-        K = self.Kb
-        P = self.P
-        assert M == K*P, "idx_t_multi second axis must have K*P columns"
-
-        pos_list, present_list, wnorm_list = [], [], []
-
-        for g in range(G):
-            idx = self.idx_t_multi[g].to(device=device, dtype=torch.long)  # (4, M)
-            w = self.w_t_multi[g].to(device=device, dtype=dtype)  # (4, M)
-
-            pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
-            in_range = pos < ids_sorted.numel()
-            cmp_vals = torch.full_like(idx, -1)
-            cmp_vals[in_range] = ids_sorted[pos[in_range]]
-            present = (cmp_vals == idx)
-
-            # normalize weights per column after masking
-            w = w * present
-            colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
-            w_norm = w / colsum
-
-            pos_safe = torch.where(present, pos, torch.zeros_like(pos))
-
-            pos_list.append(pos_safe)
-            present_list.append(present)
-            wnorm_list.append(w_norm)
-
-        self.pos_safe_t_multi = torch.stack(pos_list, dim=0)  # (G, 4, M)
-        self.present_t_multi = torch.stack(present_list, dim=0)  # (G, 4, M)
-        self.w_norm_t_multi = torch.stack(wnorm_list, dim=0)  # (G, 4, M)
-
-        # mirror runtime placement
-        self.device = device
-        self.dtype = dtype
-
-    # ------------------------------------------------------------------
-    # Step B: bind support Torch
-    # ------------------------------------------------------------------
-    def bind_support_torch(self, ids_sorted_np, *, device=None, dtype=None):
-        """
-        Map HEALPix neighbor indices (from Step A) to actual data samples
-        sorted by pixel id. Produces pos_safe and normalized weights.
-
-        Parameters
-        ----------
-        ids_sorted_np : np.ndarray (K,)
-            Sorted pixel ids for available data.
-        device, dtype : Torch device/dtype for results.
-        """
-        if device is None:
-            device = self.device
-        if dtype is None:
-            dtype = self.dtype
-
-        self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64)
-        ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
-
-        idx = self.idx_t.to(device=device, dtype=torch.long)
-        w = self.w_t.to(device=device, dtype=dtype)
-
-        M = self.Kb * self.P
-        idx = idx.view(4, M)
-        w = w.view(4, M)
-
-        pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
-        in_range = pos < ids_sorted.shape[0]
-        cmp_vals = torch.full_like(idx, -1)
-        cmp_vals[in_range] = ids_sorted[pos[in_range]]
-        present = (cmp_vals == idx)
-
-        w = w * present
-        colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
-        w_norm = w / colsum
-
-        self.pos_safe_t = torch.where(present, pos, torch.zeros_like(pos))
-        self.w_norm_t = w_norm
-        self.present_t = present
-        self.device = device
-        self.dtype = dtype
-    '''
     # ------------------------------------------------------------------
     # Step C: apply convolution (already Torch in your code)
     # ------------------------------------------------------------------
@@ -1215,7 +1109,7 @@
         vals = torch.cat(vals_all, dim=0)
 
 
-        indices = torch.stack([cols, rows], dim=0)
+        indices = torch.stack([cols, rows], dim=0)
 
         if return_sparse_tensor:
             M = torch.sparse_coo_tensor(indices, vals, size=shape, device=device, dtype=k_dtype).coalesce()
@@ -1224,123 +1118,10 @@
         return vals, indices, shape
 
 
-    def _to_numpy_1d(self, ids):
-        """Return a 1D numpy array of int64 for a single set of cell ids."""
-        import numpy as np, torch
-        if isinstance(ids, np.ndarray):
-            return ids.reshape(-1).astype(np.int64, copy=False)
-        if torch.is_tensor(ids):
-            return ids.detach().cpu().to(torch.long).view(-1).numpy()
-        # python list/tuple of ints
-        return np.asarray(ids, dtype=np.int64).reshape(-1)
-
-    def _is_varlength_batch(self, ids):
-        """
-        True if ids is a list/tuple of per-sample id arrays (var-length batch).
-        False if ids is a single array/tensor of ids (shared for whole batch).
-        """
-        import numpy as np, torch
-        if isinstance(ids, (list, tuple)):
-            return True
-        if isinstance(ids, np.ndarray) and ids.ndim == 2:
-            # This would be a dense (B, Npix) matrix -> NOT var-length list
-            return False
-        if torch.is_tensor(ids) and ids.dim() == 2:
-            return False
-        return False
-
-    def Down(self, im, cell_ids=None, nside=None,max_poll=False):
-        """
-        If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
-        If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
-        """
-        if self.f is None:
-            if self.dtype==torch.float64:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
-            else:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
-
-        if cell_ids is None:
-            dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside,max_poll=max_poll)
-            return dim,cdim
-
-        if nside is None:
-            nside = self.nside
-
-        # var-length mode: list/tuple of ids, one per sample
-        if self._is_varlength_batch(cell_ids):
-            outs, outs_ids = [], []
-            B = len(cell_ids)
-            for b in range(B):
-                cid_b = self._to_numpy_1d(cell_ids[b])
-                # extract the matching sample from `im`
-                if torch.is_tensor(im):
-                    xb = im[b:b+1]  # (1, C, N_b)
-                    yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
-                    outs.append(yb.squeeze(0))  # (C, N_b')
-                else:
-                    # if im is already a list of (C, N_b)
-                    xb = im[b]
-                    yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
-                    outs.append(yb.squeeze(0))
-                outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
-            return outs, outs_ids
-
-        # common grid (a single vector of ids)
-        cid = self._to_numpy_1d(cell_ids)
-        return self.f.ud_grade_2(im, cell_ids=cid, nside=nside,max_poll=False)
-
-    def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
-        """
-        If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
-        If they are lists (var-length per sample) -> return list[Tensor].
-        """
-        if self.f is None:
-            if self.dtype==torch.float64:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
-            else:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
-
-        if cell_ids is None:
-            dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
-            return dim
-
-        if nside is None:
-            nside = self.nside
-
-        # var-length: parallel lists
-        if self._is_varlength_batch(cell_ids):
-            assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
-                "In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
-            outs = []
-            B = len(cell_ids)
-            for b in range(B):
-                cid_b = self._to_numpy_1d(cell_ids[b])  # coarse ids
-                ocid_b = self._to_numpy_1d(o_cell_ids[b])  # fine ids
-                if torch.is_tensor(im):
-                    xb = im[b:b+1]  # (1, C, N_b_coarse)
-                    yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
-                                         o_cell_ids=ocid_b, force_init_index=True)
-                    outs.append(yb.squeeze(0))  # (C, N_b_fine)
-                else:
-                    xb = im[b]  # (C, N_b_coarse)
-                    yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
-                                         o_cell_ids=ocid_b, force_init_index=True)
-                    outs.append(yb.squeeze(0))
-            return outs
-
-        # common grid
-        cid = self._to_numpy_1d(cell_ids)
-        ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
-        return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
-                               o_cell_ids=ocid, force_init_index=True)
-
-
     def to_tensor(self,x):
-        return torch.tensor(x,device=
-
+        return torch.tensor(x,device='cuda')
+
     def to_numpy(self,x):
         if isinstance(x,np.ndarray):
             return x
-        return x.cpu().numpy()
-
+        return x.cpu().numpy()
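One step worth isolating from the hunks above is the batched gauge rotation, rotated = torch.einsum('kgij,pj->kgpi', R_t, vec_t) in the @@ -371,7 +378,7 @@ hunk: it applies K*G rotation matrices to the same (P, 3) stencil in a single contraction. A self-contained sketch with illustrative shapes; the random QR-derived rotations stand in for the package's per-target, per-gauge matrices:

import torch

K, G, P = 5, 2, 9                                  # targets, gauges, stencil taps
R_t = torch.linalg.qr(torch.randn(K, G, 3, 3)).Q   # batched orthonormal 3x3
vec_t = torch.randn(P, 3)
vec_t = vec_t / vec_t.norm(dim=-1, keepdim=True)   # unit stencil vectors

rotated = torch.einsum('kgij,pj->kgpi', R_t, vec_t)  # (K, G, P, 3)

# Orthonormal matrices preserve norms, so the taps stay on the unit sphere.
assert torch.allclose(rotated.norm(dim=-1), torch.ones(K, G, P), atol=1e-5)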
foscat/alm_loc.py
ADDED
@@ -0,0 +1,270 @@
+
+import numpy as np
+import healpy as hp
+
+from foscat.alm import alm as _alm
+import torch
+
+class alm_loc(_alm):
+    """
+    Local/partial-sky variant of foscat.alm.alm.
+
+    Key design choice (to match alm.py exactly when full-sky is provided):
+    - Reuse *all* Legendre/normalization machinery from the parent class (alm),
+      i.e. shift_ph(), compute_legendre_m(), ratio_mm, A/B recurrences, etc.
+      This is critical for matching alm.map2alm() numerically.
+
+    Differences vs alm.map2alm():
+    - Input map is [..., n] with explicit (nside, cell_ids)
+    - Only rings touched by cell_ids are processed.
+    - For rings with full coverage, we run the exact same FFT+tiling logic as alm.comp_tf()
+      (but only for those rings) -> bitwise comparable up to backend FFT differences.
+    - For rings with partial coverage, we compute a *partial DFT* for m=0..mmax,
+      using the same phase convention as alm.comp_tf():
+          FFT kernel uses exp(-i 2pi (m mod Nring) j / Nring)
+          then apply the per-ring shift exp(-i m phi0) via self.matrix_shift_ph
+    """
+
+    def __init__(self, backend=None, lmax=24, limit_range=1e10):
+        super().__init__(backend=backend, lmax=lmax, nside=None, limit_range=limit_range)
+
+    # --------- helpers: ring layout identical to alm.ring_th/ring_ph ----------
+    @staticmethod
+    def _ring_starts_sizes(nside: int):
+        starts = []
+        sizes = []
+        n = 0
+        for k in range(nside - 1):
+            N = 4 * (k + 1)
+            starts.append(n); sizes.append(N)
+            n += N
+        for _ in range(2 * nside + 1):
+            N = 4 * nside
+            starts.append(n); sizes.append(N)
+            n += N
+        for k in range(nside - 1):
+            N = 4 * (nside - 1 - k)
+            starts.append(n); sizes.append(N)
+            n += N
+        return np.asarray(starts, np.int64), np.asarray(sizes, np.int32)
+
+    def _to_ring_ids(self, nside: int, cell_ids: np.ndarray, nest: bool) -> np.ndarray:
+        if nest:
+            return hp.nest2ring(nside, cell_ids)
+        return cell_ids
+
+    def _group_by_ring(self, nside: int, ring_ids: np.ndarray):
+        """
+        Returns:
+            ring_idx: ring number (0..4*nside-2) per pixel
+            pos: position along ring (0..Nring-1) per pixel
+            order: sort order grouping by ring then pos
+            starts,sizes: ring layout
+        """
+        starts, sizes = self._ring_starts_sizes(nside)
+
+        # ring index = last start <= ring_id
+        ring_idx = np.searchsorted(starts, ring_ids, side="right") - 1
+        ring_idx = ring_idx.astype(np.int32)
+
+        pos = (ring_ids - starts[ring_idx]).astype(np.int32)
+
+        order = np.lexsort((pos, ring_idx))
+        return ring_idx, pos, order, starts, sizes
+
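A quick sanity check on the _ring_starts_sizes helper just added (the file's diff continues below): the per-ring sizes, 4k in the polar caps and 4*nside across the 2*nside+1 equatorial rings, must tile all 12*nside**2 RING-ordered pixels. A minimal sketch, assuming only numpy and healpy:

import numpy as np
import healpy as hp

def ring_sizes(nside):
    sizes = [4 * (k + 1) for k in range(nside - 1)]           # north cap
    sizes += [4 * nside] * (2 * nside + 1)                    # equatorial band
    sizes += [4 * (nside - 1 - k) for k in range(nside - 1)]  # south cap
    return np.asarray(sizes)

sizes = ring_sizes(16)
assert sizes.sum() == hp.nside2npix(16)   # 12 * nside**2 pixels
assert len(sizes) == 4 * 16 - 1           # number of iso-latitude rings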
+    # ------------------ local Fourier transform per ring ---------------------
+    def comp_tf_loc(self, im, nside: int, cell_ids, nest: bool = False, realfft: bool = True, mmax=None):
+        """
+        Returns:
+            rings_used: 1D np.ndarray of ring indices present
+            ft: backend tensor of shape [..., nrings_used, mmax+1] (complex)
+                where last axis is m, ring axis matches rings_used order.
+        """
+        nside = int(nside)
+        cell_ids = np.asarray(cell_ids, dtype=np.int64)
+        if mmax is None:
+            mmax = min(self.lmax, 3 * nside - 1)
+        mmax = int(mmax)
+
+        # Ensure parent caches for this nside exist (matrix_shift_ph, A/B, ratio_mm, etc.)
+        self.shift_ph(nside)
+
+        ring_ids = self._to_ring_ids(nside, cell_ids, nest)
+        ring_idx, pos, order, starts, sizes = self._group_by_ring(nside, ring_ids)
+
+        ring_idx = ring_idx[order]
+        pos = pos[order]
+
+        i_im = self.backend.bk_cast(im)
+        i_im = self.backend.bk_gather(i_im, order, axis=-1)  # reorder last axis
+
+        rings_used, start_ptr, counts = np.unique(ring_idx, return_index=True, return_counts=True)
+
+        # Build output per ring as list then concat
+        out_per_ring = []
+        for r, s0, cnt in zip(rings_used.tolist(), start_ptr.tolist(), counts.tolist()):
+            Nring = int(sizes[r])
+            p = pos[s0:s0+cnt]
+
+            v = self.backend.bk_gather(i_im, np.arange(s0, s0+cnt, dtype=np.int64), axis=-1)
+
+            if cnt == Nring:
+                # Full ring: exact same FFT+tiling logic as alm.comp_tf for 1 ring
+                # Need data ordered by pos (already grouped, but ensure pos is 0..N-1)
+                if not np.all(p == np.arange(Nring, dtype=p.dtype)):
+                    # reorder within ring
+                    sub_order = np.argsort(p)
+                    v = self.backend.bk_gather(v, sub_order, axis=-1)
+
+                if realfft:
+                    tmp = self.rfft2fft(v)
+                else:
+                    tmp = self.backend.bk_fft(v)
+
+                l_n = tmp.shape[-1]
+                if l_n < mmax + 1:
+                    repeat_n = (mmax // l_n) + 1
+                    tmp = self.backend.bk_tile(tmp, repeat_n, axis=-1)
+
+                tmp = tmp[..., :mmax+1]
+
+                # Apply per-ring shift exp(-i m phi0) exactly like alm.comp_tf
+                shift = self.matrix_shift_ph[nside][r, :mmax+1]  # [m]
+                tmp = tmp * shift
+                out_per_ring.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [...,1,m]
+            else:
+                # Partial ring: partial DFT for required m, using same aliasing as FFT branch
+                m_vec = np.arange(mmax+1, dtype=np.int64)
+                m_mod = (m_vec % Nring).astype(np.int64)
+
+                # angles: 2pi * pos * m_mod / Nring
+                ang = (2.0 * np.pi / Nring) * p.astype(np.float64)[:, None] * m_mod[None, :].astype(np.float64)
+                ker = np.exp(-1j * ang).astype(np.complex128)  # [cnt, m]
+
+                ker_bk = self.backend.bk_cast(ker)
+
+                # v is [..., cnt]; we want [..., m] = sum_cnt v*ker
+                tmp = self.backend.bk_reduce_sum(
+                    self.backend.bk_expand_dims(v, axis=-1) * ker_bk,
+                    axis=-2
+                )  # [..., m]
+
+                shift = self.matrix_shift_ph[nside][r, :mmax+1]  # [m] true m shift
+                tmp = tmp * shift
+                out_per_ring.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [...,1,m]
+
+        ft = self.backend.bk_concat(out_per_ring, axis=-2)  # [..., nrings, m]
+        return np.asarray(rings_used, dtype=np.int32), ft
+
+    # ---------------------------- map -> alm --------------------------------
+    def map2alm_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
+        nside = int(nside)
+        if lmax is None:
+            lmax = min(self.lmax, 3 * nside - 1)
+        lmax = int(lmax)
+
+        # Ensure a batch dimension like alm.map2alm expects
+        _added_batch = False
+        if hasattr(im, 'ndim') and im.ndim == 1:
+            im = im[None, :]
+            _added_batch = True
+        elif (not hasattr(im, 'ndim')) and len(im.shape) == 1:
+            im = im[None, :]
+            _added_batch = True
+
+        rings_used, ft = self.comp_tf_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, realfft=True, mmax=lmax)
+
+        # cos(theta) on used rings
+        co_th = np.cos(self.ring_th(nside)[rings_used])
+
+        # ft is [..., R, m]
+        alm_out = None
+
+
+
+        for m in range(lmax + 1):
+            # IMPORTANT: reuse alm.compute_legendre_m and its normalization exactly
+            plm = self.compute_legendre_m(co_th, m, lmax, nside) / (12 * nside**2)  # [L,R]
+            plm_bk = self.backend.bk_cast(plm)
+
+            ft_m = ft[..., :, m]  # [..., R]
+            tmp = self.backend.bk_reduce_sum(
+                self.backend.bk_expand_dims(ft_m, axis=-2) * plm_bk,
+                axis=-1
+            )  # [..., L]
+            l_vals = np.arange(m, lmax + 1, dtype=np.float64)
+            scale = np.sqrt(2.0 * l_vals + 1.0)
+
+            # convert scale to a backend (torch) tensor on the right device
+            scale_t = self.backend.bk_cast(scale)  # or an equivalent helper
+            # reshape for broadcasting if needed: [1, L] or [L]
+            shape = (1,) * (tmp.ndim - 1) + (scale_t.shape[0],)
+            scale_t = scale_t.reshape(shape)
+
+            tmp = tmp * scale_t
+            if m == 0:
+                alm_out = tmp
+            else:
+                alm_out = self.backend.bk_concat([alm_out, tmp], axis=-1)
+        if _added_batch:
+            alm_out = alm_out[0]
+        return alm_out
+
+    # ---------------------------- alm -> Cl ---------------------------------
+    def anafast_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
+
+        if lmax is None:
+            lmax = min(self.lmax, 3 * nside - 1)
+        lmax = int(lmax)
+
+        alm = self.map2alm_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, lmax=lmax)
+
+        # cl has same batch dims as alm, plus ell dim
+        batch_shape = alm.shape[:-1]
+        cl = torch.zeros(batch_shape + (lmax + 1,), dtype=torch.float64, device=alm.device)
+
+        idx = 0
+        for m in range(lmax + 1):
+            L = lmax - m + 1
+            a = alm[..., idx:idx+L]  # shape: batch + (L,)
+            idx += L
+
+            p = self.backend.bk_real(a * self.backend.bk_conjugate(a))  # batch + (L,)
+
+            if m == 0:
+                cl[..., m:] += p
+            else:
+                cl[..., m:] += 2.0 * p
+
+        # divide by (2l+1), broadcast over batch dims
+        denom = (2 * torch.arange(lmax + 1, dtype=cl.dtype, device=alm.device) + 1)  # (lmax+1,)
+        denom = denom.reshape((1,) * len(batch_shape) + (lmax + 1,))  # batch-broadcast
+        cl = cl / denom
+        return cl
+    '''
+    def anafast_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
+        if lmax is None:
+            lmax = min(self.lmax, 3 * nside - 1)
+        lmax = int(lmax)
+
+        alm = self.map2alm_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, lmax=lmax)
+
+        # Unpack and compute Cl with correct real-field folding:
+        cl = torch.zeros((lmax + 1,), dtype=alm.dtype, device=alm.device)
+
+        idx = 0
+        for m in range(lmax + 1):
+            L = lmax - m + 1
+            a = alm[..., idx:idx+L]
+            idx += L
+            p = self.backend.bk_real(a * self.backend.bk_conjugate(a))
+            # sum over any batch dims
+            p = self.backend.bk_reduce_sum(p, axis=tuple(range(p.ndim-1))) if p.ndim > 1 else p
+            if m == 0:
+                cl[m:] += p
+            else:
+                cl[m:] += 2.0 * p
+        denom = (2*torch.arange(lmax+1,dtype=p.dtype, device=alm.device)+1)
+        cl = cl / denom
+        return cl
+    '''
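The partial-ring branch of comp_tf_loc evaluates the sum over available pixels of v_j * exp(-i 2pi p_j (m mod Nring) / Nring), i.e. the same aliasing the full-ring branch gets by tiling the FFT. A small numpy sketch checking that equivalence on a fully covered ring (the Nring and mmax values are illustrative):

import numpy as np

Nring, mmax = 16, 23                  # ring size, requested max m
v = np.random.randn(Nring)            # one fully covered ring
p = np.arange(Nring)                  # position of each sample on the ring

m_mod = np.arange(mmax + 1) % Nring   # alias m >= Nring, as in the FFT branch
ker = np.exp(-2j * np.pi * p[:, None] * m_mod[None, :] / Nring)  # [Nring, m]
partial = v @ ker                     # explicit partial DFT, shape (mmax+1,)

tiled = np.tile(np.fft.fft(v), mmax // Nring + 1)[:mmax + 1]  # FFT + tiling
assert np.allclose(partial, tiled)    # identical up to float rounding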
foscat/scat.py
CHANGED
@@ -1659,7 +1659,7 @@ class funct(FOC.FoCUS):
             s2j2 = None
             l2_image = None
             for j1 in range(jmax):
-                if j1 < jmax
+                if j1 < jmax:  # stop to add scales
                     # Convol image along the axis defined by 'axis' using the wavelet defined at
                     # the foscat initialisation
                     # c_image_real is [....,Npix_j1,....,Norient]
foscat/scat1D.py
CHANGED
@@ -1282,7 +1282,7 @@ class funct(FOC.FoCUS):
             l2_image = None
 
             for j1 in range(jmax):
-                if j1 < jmax
+                if j1 < jmax:  # stop to add scales
                     # Convol image along the axis defined by 'axis' using the wavelet defined at
                     # the foscat initialisation
                     # c_image_real is [....,Npix_j1,....,Norient]