foscat 2025.10.2-py3-none-any.whl → 2026.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,13 +2,8 @@
 # Author: J.-M. Delouis
 import numpy as np
 import healpy as hp
-import foscat.scat_cov as sc
 import torch
 
-import numpy as np
-import torch
-import healpy as hp
-
 
 class SphericalStencil:
     """
@@ -61,8 +56,7 @@ class SphericalStencil:
         device=None,
         dtype=None,
         n_gauges=1,
-        gauge_type='cosmo',
-        scat_op=None,
+        gauge_type='phi',
     ):
         assert kernel_sz >= 1 and int(kernel_sz) == kernel_sz
         assert kernel_sz % 2 == 1, "kernel_sz must be odd"
@@ -75,10 +69,6 @@ class SphericalStencil:
         self.gauge_type=gauge_type
 
         self.nest = bool(nest)
-        if scat_op is None:
-            self.f=sc.funct(KERNELSZ=self.KERNELSZ)
-        else:
-            self.f=scat_op
 
         # Torch defaults
         if device is None:
@@ -354,10 +344,27 @@
         # --- build the local (P,3) stencil once on device
         P = self.P
         vec_np = np.zeros((P, 3), dtype=float)
-        grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2) / self.nside
-        vec_np[:, 0] = np.tile(grid, self.KERNELSZ)
-        vec_np[:, 1] = np.repeat(grid, self.KERNELSZ)
-        vec_np[:, 2] = 1.0 - np.sqrt(vec_np[:, 0]**2 + vec_np[:, 1]**2)
+        grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2)
+
+        # NEW: angular offsets
+        xx,yy=np.meshgrid(grid,grid)
+        s=1.0  # could be modified
+        alpha_pix = hp.nside2resol(self.nside, arcmin=False)  # ~ typical angular pixel size
+        dtheta = (np.sqrt(xx**2+yy**2) * alpha_pix * s).ravel()
+        dphi = (np.arctan2(yy,xx)).ravel()
+        # local spherical displacement
+        # convert to unit vectors
+        x = np.sin(dtheta) * np.cos(dphi)
+        y = np.sin(dtheta) * np.sin(dphi)
+        z = np.cos(dtheta)
+        #print(self.nside*x.reshape(self.KERNELSZ,self.KERNELSZ))
+        #print(self.nside*y.reshape(self.KERNELSZ,self.KERNELSZ))
+        #print(self.nside*z.reshape(self.KERNELSZ,self.KERNELSZ))
+        vec_np = np.stack([x, y, z], axis=-1)
+
+        #vec_np[:, 0] = np.tile(grid, self.KERNELSZ)
+        #vec_np[:, 1] = np.repeat(grid, self.KERNELSZ)
+        #vec_np[:, 2] = 1.0 - np.sqrt(vec_np[:, 0]**2 + vec_np[:, 1]**2)
 
         vec_t = torch.as_tensor(vec_np, device=self.device, dtype=self.dtype)  # (P,3)
 
         # --- rotation matrices for all targets & gauges: (K,G,3,3)
@@ -371,7 +378,7 @@ class SphericalStencil:
             th, ph, alpha, G=self.G, gauge_cosmo=(self.gauge_type=='cosmo'),
             device=self.device, dtype=self.dtype
         )  # shape (K,G,3,3)
-
+
         # --- rotate stencil for each (target, gauge): (K,G,P,3)
         # einsum over local stencil (P,3) with rotation (K,G,3,3)
         rotated = torch.einsum('kgij,pj->kgpi', R_t, vec_t)  # (K,G,P,3)
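The stencil built in the hunk above now lives genuinely on the sphere: each kernel cell sits at colatitude dtheta (its grid radius times the HEALPix pixel resolution) and azimuth dphi, then becomes a Cartesian unit vector, replacing the flat tangent-plane offsets whose third component 1 - sqrt(x**2 + y**2) only approximates a unit vector. A minimal standalone sketch of that geometry (the helper name local_stencil is illustrative, not part of foscat):

    import numpy as np
    import healpy as hp

    def local_stencil(kernel_sz, nside, s=1.0):
        # integer offsets around the kernel centre
        grid = np.arange(kernel_sz) - kernel_sz // 2
        xx, yy = np.meshgrid(grid, grid)
        alpha_pix = hp.nside2resol(nside)                    # typical pixel size [rad]
        dtheta = np.hypot(xx, yy).ravel() * alpha_pix * s    # radial angular offset
        dphi = np.arctan2(yy, xx).ravel()                    # azimuth of the offset
        # polar-coordinate displacement -> Cartesian unit vectors, shape (P, 3)
        return np.stack([np.sin(dtheta) * np.cos(dphi),
                         np.sin(dtheta) * np.sin(dphi),
                         np.cos(dtheta)], axis=-1)

    vecs = local_stencil(5, nside=64)
    # every stencil point now lies exactly on the unit sphere
    assert np.allclose(np.linalg.norm(vecs, axis=1), 1.0)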
@@ -568,119 +575,6 @@ class SphericalStencil:
         self.dtype = dtype
 
 
-    '''
-    def bind_support_torch_multi(self, ids_sorted_np, *, device=None, dtype=None):
-        """
-        Multi-gauge sparse binding (Step B).
-        Uses self.idx_t_multi / self.w_t_multi prepared by prepare_torch(..., G>1)
-        and builds, for each gauge g, (pos_safe, w_norm, present).
-
-        Parameters
-        ----------
-        ids_sorted_np : np.ndarray (K,)
-            Sorted pixel ids for available samples (matches the last axis of your data).
-        device, dtype : torch device/dtype for the produced mapping tensors.
-
-        Side effects
-        ------------
-        Sets:
-          - self.ids_sorted_np    : (K,)
-          - self.pos_safe_t_multi : (G, 4, K*P) LongTensor
-          - self.w_norm_t_multi   : (G, 4, K*P) Tensor
-          - self.present_t_multi  : (G, 4, K*P) BoolTensor
-          - (and mirrors device/dtype in self.device/self.dtype)
-        """
-        assert hasattr(self, 'idx_t_multi') and self.idx_t_multi is not None, \
-            "Call prepare_torch(..., G>0) before bind_support_torch_multi(...)"
-        assert hasattr(self, 'w_t_multi') and self.w_t_multi is not None
-
-        if device is None: device = self.device
-        if dtype is None: dtype = self.dtype
-
-        self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64).reshape(-1)
-        ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
-
-        G, _, M = self.idx_t_multi.shape
-        K = self.Kb
-        P = self.P
-        assert M == K*P, "idx_t_multi second axis must have K*P columns"
-
-        pos_list, present_list, wnorm_list = [], [], []
-
-        for g in range(G):
-            idx = self.idx_t_multi[g].to(device=device, dtype=torch.long)  # (4, M)
-            w = self.w_t_multi[g].to(device=device, dtype=dtype)           # (4, M)
-
-            pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
-            in_range = pos < ids_sorted.numel()
-            cmp_vals = torch.full_like(idx, -1)
-            cmp_vals[in_range] = ids_sorted[pos[in_range]]
-            present = (cmp_vals == idx)
-
-            # normalize weights per column after masking
-            w = w * present
-            colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
-            w_norm = w / colsum
-
-            pos_safe = torch.where(present, pos, torch.zeros_like(pos))
-
-            pos_list.append(pos_safe)
-            present_list.append(present)
-            wnorm_list.append(w_norm)
-
-        self.pos_safe_t_multi = torch.stack(pos_list, dim=0)     # (G, 4, M)
-        self.present_t_multi = torch.stack(present_list, dim=0)  # (G, 4, M)
-        self.w_norm_t_multi = torch.stack(wnorm_list, dim=0)     # (G, 4, M)
-
-        # mirror runtime placement
-        self.device = device
-        self.dtype = dtype
-
-    # ------------------------------------------------------------------
-    # Step B: bind support Torch
-    # ------------------------------------------------------------------
-    def bind_support_torch(self, ids_sorted_np, *, device=None, dtype=None):
-        """
-        Map HEALPix neighbor indices (from Step A) to actual data samples
-        sorted by pixel id. Produces pos_safe and normalized weights.
-
-        Parameters
-        ----------
-        ids_sorted_np : np.ndarray (K,)
-            Sorted pixel ids for available data.
-        device, dtype : Torch device/dtype for results.
-        """
-        if device is None:
-            device = self.device
-        if dtype is None:
-            dtype = self.dtype
-
-        self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64)
-        ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
-
-        idx = self.idx_t.to(device=device, dtype=torch.long)
-        w = self.w_t.to(device=device, dtype=dtype)
-
-        M = self.Kb * self.P
-        idx = idx.view(4, M)
-        w = w.view(4, M)
-
-        pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
-        in_range = pos < ids_sorted.shape[0]
-        cmp_vals = torch.full_like(idx, -1)
-        cmp_vals[in_range] = ids_sorted[pos[in_range]]
-        present = (cmp_vals == idx)
-
-        w = w * present
-        colsum = w.sum(dim=0, keepdim=True).clamp_min(1e-12)
-        w_norm = w / colsum
-
-        self.pos_safe_t = torch.where(present, pos, torch.zeros_like(pos))
-        self.w_norm_t = w_norm
-        self.present_t = present
-        self.device = device
-        self.dtype = dtype
-    '''
 
     # ------------------------------------------------------------------
     # Step C: apply convolution (already Torch in your code)
     # ------------------------------------------------------------------
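The Step B helpers deleted above rely on a searchsorted-based binding idiom: neighbor pixel ids are looked up in the sorted list of ids that actually carry data, absent neighbors are masked out, and their interpolation weights are renormalized per column. A self-contained sketch of that idiom with toy sizes (not the foscat API):

    import torch

    ids_sorted = torch.tensor([2, 5, 7, 11])    # pixels that carry data, sorted
    idx = torch.tensor([[5, 3], [11, 7]])       # neighbor ids to bind
    w = torch.tensor([[0.6, 0.4], [0.3, 0.7]])  # interpolation weights

    pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view_as(idx)
    in_range = pos < ids_sorted.numel()
    cmp_vals = torch.full_like(idx, -1)
    cmp_vals[in_range] = ids_sorted[pos[in_range]]
    present = cmp_vals == idx                   # is the neighbor actually available?

    w = w * present                             # mask missing neighbors...
    w_norm = w / w.sum(dim=0, keepdim=True).clamp_min(1e-12)  # ...then renormalize
    pos_safe = torch.where(present, pos, torch.zeros_like(pos))  # safe gather index

Here id 3 is absent from ids_sorted, so its column weight is zeroed and pos_safe falls back to 0, exactly the pattern the removed methods applied per gauge on (4, K*P) tensors.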
@@ -1215,7 +1109,7 @@ class SphericalStencil:
         vals = torch.cat(vals_all, dim=0)
 
 
-        indices = torch.stack([cols, rows], dim=0)  # (2, nnz) invert rows/cols for foscat needs
+        indices = torch.stack([cols, rows], dim=0)
 
         if return_sparse_tensor:
             M = torch.sparse_coo_tensor(indices, vals, size=shape, device=device, dtype=k_dtype).coalesce()
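The (vals, indices, shape) triple above is the standard torch COO layout. A tiny generic usage sketch (made-up values; note that the hunk above stacks [cols, rows], i.e. a transposed convention relative to the plain [row, col] layout shown here):

    import torch

    indices = torch.tensor([[0, 1, 1],          # row of each non-zero
                            [2, 0, 2]])         # column of each non-zero
    vals = torch.tensor([0.5, 1.0, -0.25])
    M = torch.sparse_coo_tensor(indices, vals, size=(2, 3)).coalesce()
    x = torch.ones(3, 1)
    y = torch.sparse.mm(M, x)                   # apply the stencil operator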
@@ -1224,123 +1118,10 @@ class SphericalStencil:
         return vals, indices, shape
 
 
-    def _to_numpy_1d(self, ids):
-        """Return a 1D numpy array of int64 for a single set of cell ids."""
-        import numpy as np, torch
-        if isinstance(ids, np.ndarray):
-            return ids.reshape(-1).astype(np.int64, copy=False)
-        if torch.is_tensor(ids):
-            return ids.detach().cpu().to(torch.long).view(-1).numpy()
-        # python list/tuple of ints
-        return np.asarray(ids, dtype=np.int64).reshape(-1)
-
-    def _is_varlength_batch(self, ids):
-        """
-        True  if ids is a list/tuple of per-sample id arrays (var-length batch).
-        False if ids is a single array/tensor of ids (shared for whole batch).
-        """
-        import numpy as np, torch
-        if isinstance(ids, (list, tuple)):
-            return True
-        if isinstance(ids, np.ndarray) and ids.ndim == 2:
-            # This would be a dense (B, Npix) matrix -> NOT var-length list
-            return False
-        if torch.is_tensor(ids) and ids.dim() == 2:
-            return False
-        return False
-
-    def Down(self, im, cell_ids=None, nside=None, max_poll=False):
-        """
-        If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
-        If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
-        """
-        if self.f is None:
-            if self.dtype==torch.float64:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
-            else:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
-
-        if cell_ids is None:
-            dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside,max_poll=max_poll)
-            return dim,cdim
-
-        if nside is None:
-            nside = self.nside
-
-        # var-length mode: list/tuple of ids, one per sample
-        if self._is_varlength_batch(cell_ids):
-            outs, outs_ids = [], []
-            B = len(cell_ids)
-            for b in range(B):
-                cid_b = self._to_numpy_1d(cell_ids[b])
-                # extract the matching sample from `im`
-                if torch.is_tensor(im):
-                    xb = im[b:b+1]  # (1, C, N_b)
-                    yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
-                    outs.append(yb.squeeze(0))  # (C, N_b')
-                else:
-                    # if im is already a list of (C, N_b)
-                    xb = im[b]
-                    yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
-                    outs.append(yb.squeeze(0))
-                outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
-            return outs, outs_ids
-
-        # common grid (a single vector of ids)
-        cid = self._to_numpy_1d(cell_ids)
-        return self.f.ud_grade_2(im, cell_ids=cid, nside=nside,max_poll=False)
-
-    def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
-        """
-        If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
-        If they are lists (var-length per sample)      -> return list[Tensor].
-        """
-        if self.f is None:
-            if self.dtype==torch.float64:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
-            else:
-                self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
-
-        if cell_ids is None:
-            dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
-            return dim
-
-        if nside is None:
-            nside = self.nside
-
-        # var-length: parallel lists
-        if self._is_varlength_batch(cell_ids):
-            assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
-                "In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
-            outs = []
-            B = len(cell_ids)
-            for b in range(B):
-                cid_b = self._to_numpy_1d(cell_ids[b])     # coarse ids
-                ocid_b = self._to_numpy_1d(o_cell_ids[b])  # fine ids
-                if torch.is_tensor(im):
-                    xb = im[b:b+1]  # (1, C, N_b_coarse)
-                    yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
-                                         o_cell_ids=ocid_b, force_init_index=True)
-                    outs.append(yb.squeeze(0))  # (C, N_b_fine)
-                else:
-                    xb = im[b]  # (C, N_b_coarse)
-                    yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
-                                         o_cell_ids=ocid_b, force_init_index=True)
-                    outs.append(yb.squeeze(0))
-            return outs
-
-        # common grid
-        cid = self._to_numpy_1d(cell_ids)
-        ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
-        return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
-                               o_cell_ids=ocid, force_init_index=True)
-
-
     def to_tensor(self,x):
-        return torch.tensor(x,device=self.device,dtype=self.dtype)
-
+        return torch.tensor(x,device='cuda')
+
     def to_numpy(self,x):
         if isinstance(x,np.ndarray):
             return x
-        return x.cpu().numpy()
-
+        return x.cpu().numpy()
foscat/alm_loc.py ADDED
@@ -0,0 +1,270 @@
+
+import numpy as np
+import healpy as hp
+
+from foscat.alm import alm as _alm
+import torch
+
+class alm_loc(_alm):
+    """
+    Local/partial-sky variant of foscat.alm.alm.
+
+    Key design choice (to match alm.py exactly when full-sky is provided):
+    - Reuse *all* Legendre/normalization machinery from the parent class (alm),
+      i.e. shift_ph(), compute_legendre_m(), ratio_mm, A/B recurrences, etc.
+      This is critical for matching alm.map2alm() numerically.
+
+    Differences vs alm.map2alm():
+    - Input map is [..., n] with explicit (nside, cell_ids)
+    - Only rings touched by cell_ids are processed.
+    - For rings with full coverage, we run the exact same FFT+tiling logic as alm.comp_tf()
+      (but only for those rings) -> bitwise comparable up to backend FFT differences.
+    - For rings with partial coverage, we compute a *partial DFT* for m=0..mmax,
+      using the same phase convention as alm.comp_tf():
+          FFT kernel uses exp(-i 2pi (m mod Nring) j / Nring)
+      then apply the per-ring shift exp(-i m phi0) via self.matrix_shift_ph
+    """
+
+    def __init__(self, backend=None, lmax=24, limit_range=1e10):
+        super().__init__(backend=backend, lmax=lmax, nside=None, limit_range=limit_range)
+
+    # --------- helpers: ring layout identical to alm.ring_th/ring_ph ----------
+    @staticmethod
+    def _ring_starts_sizes(nside: int):
+        starts = []
+        sizes = []
+        n = 0
+        for k in range(nside - 1):
+            N = 4 * (k + 1)
+            starts.append(n); sizes.append(N)
+            n += N
+        for _ in range(2 * nside + 1):
+            N = 4 * nside
+            starts.append(n); sizes.append(N)
+            n += N
+        for k in range(nside - 1):
+            N = 4 * (nside - 1 - k)
+            starts.append(n); sizes.append(N)
+            n += N
+        return np.asarray(starts, np.int64), np.asarray(sizes, np.int32)
+
+    def _to_ring_ids(self, nside: int, cell_ids: np.ndarray, nest: bool) -> np.ndarray:
+        if nest:
+            return hp.nest2ring(nside, cell_ids)
+        return cell_ids
+
+    def _group_by_ring(self, nside: int, ring_ids: np.ndarray):
+        """
+        Returns:
+          ring_idx:     ring number (0..4*nside-2) per pixel
+          pos:          position along ring (0..Nring-1) per pixel
+          order:        sort order grouping by ring then pos
+          starts,sizes: ring layout
+        """
+        starts, sizes = self._ring_starts_sizes(nside)
+
+        # ring index = last start <= ring_id
+        ring_idx = np.searchsorted(starts, ring_ids, side="right") - 1
+        ring_idx = ring_idx.astype(np.int32)
+
+        pos = (ring_ids - starts[ring_idx]).astype(np.int32)
+
+        order = np.lexsort((pos, ring_idx))
+        return ring_idx, pos, order, starts, sizes
+
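As a quick standalone sanity check (a re-derivation, not the package code): the RING-ordering layout encoded by _ring_starts_sizes has 4(k+1) pixels on the k-th polar-cap ring and 4*nside pixels on each of the 2*nside+1 equatorial rings, which together cover all 12*nside**2 pixels:

    import numpy as np

    def ring_sizes(nside):
        cap = [4 * (k + 1) for k in range(nside - 1)]
        eq = [4 * nside] * (2 * nside + 1)
        return np.array(cap + eq + cap[::-1])

    for nside in (1, 2, 16, 64):
        assert ring_sizes(nside).sum() == 12 * nside**2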
+    # ------------------ local Fourier transform per ring ---------------------
+    def comp_tf_loc(self, im, nside: int, cell_ids, nest: bool = False, realfft: bool = True, mmax=None):
+        """
+        Returns:
+          rings_used: 1D np.ndarray of ring indices present
+          ft:         backend tensor of shape [..., nrings_used, mmax+1] (complex)
+                      where last axis is m, ring axis matches rings_used order.
+        """
+        nside = int(nside)
+        cell_ids = np.asarray(cell_ids, dtype=np.int64)
+        if mmax is None:
+            mmax = min(self.lmax, 3 * nside - 1)
+        mmax = int(mmax)
+
+        # Ensure parent caches for this nside exist (matrix_shift_ph, A/B, ratio_mm, etc.)
+        self.shift_ph(nside)
+
+        ring_ids = self._to_ring_ids(nside, cell_ids, nest)
+        ring_idx, pos, order, starts, sizes = self._group_by_ring(nside, ring_ids)
+
+        ring_idx = ring_idx[order]
+        pos = pos[order]
+
+        i_im = self.backend.bk_cast(im)
+        i_im = self.backend.bk_gather(i_im, order, axis=-1)  # reorder last axis
+
+        rings_used, start_ptr, counts = np.unique(ring_idx, return_index=True, return_counts=True)
+
+        # Build output per ring as list then concat
+        out_per_ring = []
+        for r, s0, cnt in zip(rings_used.tolist(), start_ptr.tolist(), counts.tolist()):
+            Nring = int(sizes[r])
+            p = pos[s0:s0+cnt]
+
+            v = self.backend.bk_gather(i_im, np.arange(s0, s0+cnt, dtype=np.int64), axis=-1)
+
+            if cnt == Nring:
+                # Full ring: exact same FFT+tiling logic as alm.comp_tf for 1 ring
+                # Need data ordered by pos (already grouped, but ensure pos is 0..N-1)
+                if not np.all(p == np.arange(Nring, dtype=p.dtype)):
+                    # reorder within ring
+                    sub_order = np.argsort(p)
+                    v = self.backend.bk_gather(v, sub_order, axis=-1)
+
+                if realfft:
+                    tmp = self.rfft2fft(v)
+                else:
+                    tmp = self.backend.bk_fft(v)
+
+                l_n = tmp.shape[-1]
+                if l_n < mmax + 1:
+                    repeat_n = (mmax // l_n) + 1
+                    tmp = self.backend.bk_tile(tmp, repeat_n, axis=-1)
+
+                tmp = tmp[..., :mmax+1]
+
+                # Apply per-ring shift exp(-i m phi0) exactly like alm.comp_tf
+                shift = self.matrix_shift_ph[nside][r, :mmax+1]  # [m]
+                tmp = tmp * shift
+                out_per_ring.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [...,1,m]
+            else:
+                # Partial ring: partial DFT for required m, using same aliasing as FFT branch
+                m_vec = np.arange(mmax+1, dtype=np.int64)
+                m_mod = (m_vec % Nring).astype(np.int64)
+
+                # angles: 2pi * pos * m_mod / Nring
+                ang = (2.0 * np.pi / Nring) * p.astype(np.float64)[:, None] * m_mod[None, :].astype(np.float64)
+                ker = np.exp(-1j * ang).astype(np.complex128)  # [cnt, m]
+
+                ker_bk = self.backend.bk_cast(ker)
+
+                # v is [..., cnt]; we want [..., m] = sum_cnt v*ker
+                tmp = self.backend.bk_reduce_sum(
+                    self.backend.bk_expand_dims(v, axis=-1) * ker_bk,
+                    axis=-2
+                )  # [..., m]
+
+                shift = self.matrix_shift_ph[nside][r, :mmax+1]  # [m] true m shift
+                tmp = tmp * shift
+                out_per_ring.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [...,1,m]
+
+        ft = self.backend.bk_concat(out_per_ring, axis=-2)  # [..., nrings, m]
+        return np.asarray(rings_used, dtype=np.int32), ft
+
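The full-ring and partial-ring branches above agree by construction: for a ring of Nring samples, a DFT evaluated at the aliased frequency m mod Nring equals the tiled FFT coefficient. A small pure-numpy check of that identity (same phase convention exp(-i 2pi m j / Nring); the phi0 shift is omitted since both branches apply it identically):

    import numpy as np

    Nring, mmax = 8, 12
    v = np.random.default_rng(0).normal(size=Nring)

    m = np.arange(mmax + 1)
    ker = np.exp(-2j * np.pi * np.arange(Nring)[:, None] * (m % Nring) / Nring)
    partial = v @ ker                                    # partial DFT at aliased m

    tiled = np.tile(np.fft.fft(v), mmax // Nring + 1)[:mmax + 1]  # FFT then tile in m

    assert np.allclose(partial, tiled)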
+    # ---------------------------- map -> alm --------------------------------
+    def map2alm_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
+        nside = int(nside)
+        if lmax is None:
+            lmax = min(self.lmax, 3 * nside - 1)
+        lmax = int(lmax)
+
+        # Ensure a batch dimension like alm.map2alm expects
+        _added_batch = False
+        if hasattr(im, 'ndim') and im.ndim == 1:
+            im = im[None, :]
+            _added_batch = True
+        elif (not hasattr(im, 'ndim')) and len(im.shape) == 1:
+            im = im[None, :]
+            _added_batch = True
+
+        rings_used, ft = self.comp_tf_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, realfft=True, mmax=lmax)
+
+        # cos(theta) on used rings
+        co_th = np.cos(self.ring_th(nside)[rings_used])
+
+        # ft is [..., R, m]
+        alm_out = None
+
+        for m in range(lmax + 1):
+            # IMPORTANT: reuse alm.compute_legendre_m and its normalization exactly
+            plm = self.compute_legendre_m(co_th, m, lmax, nside) / (12 * nside**2)  # [L,R]
+            plm_bk = self.backend.bk_cast(plm)
+
+            ft_m = ft[..., :, m]  # [..., R]
+            tmp = self.backend.bk_reduce_sum(
+                self.backend.bk_expand_dims(ft_m, axis=-2) * plm_bk,
+                axis=-1
+            )  # [..., L]
+            l_vals = np.arange(m, lmax + 1, dtype=np.float64)
+            scale = np.sqrt(2.0 * l_vals + 1.0)
+
+            # convert scale to a backend (torch) tensor on the right device
+            scale_t = self.backend.bk_cast(scale)  # or an equivalent helper
+            # reshape for broadcasting if needed: [1, L] or [L]
+            shape = (1,) * (tmp.ndim - 1) + (scale_t.shape[0],)
+            scale_t = scale_t.reshape(shape)
+
+            tmp = tmp * scale_t
+            if m == 0:
+                alm_out = tmp
+            else:
+                alm_out = self.backend.bk_concat([alm_out, tmp], axis=-1)
+        if _added_batch:
+            alm_out = alm_out[0]
+        return alm_out
+
+    # ---------------------------- alm -> Cl ---------------------------------
+    def anafast_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
+
+        if lmax is None:
+            lmax = min(self.lmax, 3 * nside - 1)
+        lmax = int(lmax)
+
+        alm = self.map2alm_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, lmax=lmax)
+
+        # cl has same batch dims as alm, plus ell dim
+        batch_shape = alm.shape[:-1]
+        cl = torch.zeros(batch_shape + (lmax + 1,), dtype=torch.float64, device=alm.device)
+
+        idx = 0
+        for m in range(lmax + 1):
+            L = lmax - m + 1
+            a = alm[..., idx:idx+L]  # shape: batch + (L,)
+            idx += L
+
+            p = self.backend.bk_real(a * self.backend.bk_conjugate(a))  # batch + (L,)
+
+            if m == 0:
+                cl[..., m:] += p
+            else:
+                cl[..., m:] += 2.0 * p
+
+        # divide by (2l+1), broadcast over batch dims
+        denom = (2 * torch.arange(lmax + 1, dtype=cl.dtype, device=alm.device) + 1)  # (lmax+1,)
+        denom = denom.reshape((1,) * len(batch_shape) + (lmax + 1,))  # batch-broadcast
+        cl = cl / denom
+        return cl
+    '''
+    def anafast_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
+        if lmax is None:
+            lmax = min(self.lmax, 3 * nside - 1)
+        lmax = int(lmax)
+
+        alm = self.map2alm_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, lmax=lmax)
+
+        # Unpack and compute Cl with correct real-field folding:
+        cl = torch.zeros((lmax + 1,), dtype=alm.dtype, device=alm.device)
+
+        idx = 0
+        for m in range(lmax + 1):
+            L = lmax - m + 1
+            a = alm[..., idx:idx+L]
+            idx += L
+            p = self.backend.bk_real(a * self.backend.bk_conjugate(a))
+            # sum over any batch dims
+            p = self.backend.bk_reduce_sum(p, axis=tuple(range(p.ndim-1))) if p.ndim > 1 else p
+            if m == 0:
+                cl[m:] += p
+            else:
+                cl[m:] += 2.0 * p
+        denom = (2*torch.arange(lmax+1,dtype=p.dtype, device=alm.device)+1)
+        cl = cl / denom
+        return cl
+    '''
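Both variants of anafast_loc implement the usual real-field folding of an m-packed alm vector, C_l = (|a_{l,0}|**2 + 2 * sum_{m=1..l} |a_{l,m}|**2) / (2l + 1). Assuming the idx walk above (m-major blocks of length lmax - m + 1) matches healpy's alm ordering, which uses the same m-major layout, the folding can be checked against hp.alm2cl:

    import numpy as np
    import healpy as hp

    lmax = 8
    rng = np.random.default_rng(1)
    n = hp.Alm.getsize(lmax)
    alm = rng.normal(size=n) + 1j * rng.normal(size=n)

    cl = np.zeros(lmax + 1)
    idx = 0
    for m in range(lmax + 1):
        L = lmax - m + 1
        a = alm[idx:idx + L]          # block of a_{l,m} for l = m..lmax
        idx += L
        cl[m:] += (1.0 if m == 0 else 2.0) * np.abs(a) ** 2
    cl /= 2 * np.arange(lmax + 1) + 1

    assert np.allclose(cl, hp.alm2cl(alm))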
foscat/scat.py CHANGED
@@ -1659,7 +1659,7 @@ class funct(FOC.FoCUS):
         s2j2 = None
         l2_image = None
         for j1 in range(jmax):
-            if j1 < jmax - self.OSTEP:  # stop to add scales
+            if j1 < jmax:  # stop to add scales
                 # Convol image along the axis defined by 'axis' using the wavelet defined at
                 # the foscat initialisation
                 # c_image_real is [....,Npix_j1,....,Norient]
foscat/scat1D.py CHANGED
@@ -1282,7 +1282,7 @@ class funct(FOC.FoCUS):
         l2_image = None
 
         for j1 in range(jmax):
-            if j1 < jmax - self.OSTEP:  # stop to add scales
+            if j1 < jmax:  # stop to add scales
                 # Convol image along the axis defined by 'axis' using the wavelet defined at
                 # the foscat initialisation
                 # c_image_real is [....,Npix_j1,....,Norient]
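In both scat.py and scat1D.py the guard becomes `j1 < jmax`, which is always true inside `for j1 in range(jmax)`: every scale now enters the convolution branch, whereas the last `self.OSTEP` scales were previously skipped. An illustrative comparison of the two loop bounds (values made up):

    jmax, OSTEP = 8, 2
    old_scales = [j1 for j1 in range(jmax) if j1 < jmax - OSTEP]  # [0, 1, 2, 3, 4, 5]
    new_scales = [j1 for j1 in range(jmax) if j1 < jmax]          # [0, 1, ..., 7]
    assert new_scales == list(range(jmax))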