foscat 2025.11.1-py3-none-any.whl → 2026.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/FoCUS.py +71 -16
- foscat/SphereDownGeo.py +380 -0
- foscat/SphereUpGeo.py +175 -0
- foscat/SphericalStencil.py +27 -246
- foscat/alm_loc.py +270 -0
- foscat/healpix_vit_torch-old.py +658 -0
- foscat/scat_cov.py +24 -24
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/METADATA +1 -69
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/RECORD +12 -8
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/WHEEL +1 -1
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.11.1.dist-info → foscat-2026.2.1.dist-info}/top_level.txt +0 -0
foscat/alm_loc.py
ADDED
@@ -0,0 +1,270 @@

import numpy as np
import healpy as hp

from foscat.alm import alm as _alm
import torch

class alm_loc(_alm):
    """
    Local/partial-sky variant of foscat.alm.alm.

    Key design choice (to match alm.py exactly when full-sky input is provided):
    - Reuse *all* Legendre/normalization machinery from the parent class (alm),
      i.e. shift_ph(), compute_legendre_m(), ratio_mm, the A/B recurrences, etc.
      This is critical for matching alm.map2alm() numerically.

    Differences vs alm.map2alm():
    - The input map is [..., n] with explicit (nside, cell_ids).
    - Only rings touched by cell_ids are processed.
    - For rings with full coverage, we run the exact same FFT+tiling logic as
      alm.comp_tf() (but only for those rings) -> bitwise comparable up to
      backend FFT differences.
    - For rings with partial coverage, we compute a *partial DFT* for m=0..mmax,
      using the same phase convention as alm.comp_tf(): the FFT kernel uses
      exp(-i 2pi (m mod Nring) j / Nring), then the per-ring shift
      exp(-i m phi0) is applied via self.matrix_shift_ph.
    """

    def __init__(self, backend=None, lmax=24, limit_range=1e10):
        super().__init__(backend=backend, lmax=lmax, nside=None, limit_range=limit_range)

    # --------- helpers: ring layout identical to alm.ring_th/ring_ph ----------
    @staticmethod
    def _ring_starts_sizes(nside: int):
        starts = []
        sizes = []
        n = 0
        # north polar cap: rings of 4, 8, ..., 4*(nside-1) pixels
        for k in range(nside - 1):
            N = 4 * (k + 1)
            starts.append(n); sizes.append(N)
            n += N
        # equatorial belt: 2*nside+1 rings of 4*nside pixels
        for _ in range(2 * nside + 1):
            N = 4 * nside
            starts.append(n); sizes.append(N)
            n += N
        # south polar cap: mirror of the north cap
        for k in range(nside - 1):
            N = 4 * (nside - 1 - k)
            starts.append(n); sizes.append(N)
            n += N
        return np.asarray(starts, np.int64), np.asarray(sizes, np.int32)
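
    # Layout check: (nside-1) + (2*nside+1) + (nside-1) = 4*nside - 1 rings, and
    # the sizes sum to 12*nside**2 pixels; e.g. nside=2 gives sizes
    # [4, 8, 8, 8, 8, 8, 4], i.e. 48 = 12*2**2 pixels over 7 rings.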

    def _to_ring_ids(self, nside: int, cell_ids: np.ndarray, nest: bool) -> np.ndarray:
        if nest:
            return hp.nest2ring(nside, cell_ids)
        return cell_ids

    def _group_by_ring(self, nside: int, ring_ids: np.ndarray):
        """
        Returns:
            ring_idx: ring number (0..4*nside-2) per pixel
            pos: position along ring (0..Nring-1) per pixel
            order: sort order grouping by ring then pos
            starts, sizes: ring layout
        """
        starts, sizes = self._ring_starts_sizes(nside)

        # ring index = index of the last start <= ring_id
        ring_idx = np.searchsorted(starts, ring_ids, side="right") - 1
        ring_idx = ring_idx.astype(np.int32)

        pos = (ring_ids - starts[ring_idx]).astype(np.int32)

        order = np.lexsort((pos, ring_idx))
        return ring_idx, pos, order, starts, sizes
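
    # Example: nside=1 has starts=[0, 4, 8] and sizes=[4, 4, 4]; ring pixel ids
    # [5, 1, 9] give ring_idx=[1, 0, 2] and pos=[1, 1, 1], and np.lexsort
    # reorders them as [1, 5, 9] (grouped by ring, then by position).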

    # ------------------ local Fourier transform per ring ---------------------
    def comp_tf_loc(self, im, nside: int, cell_ids, nest: bool = False, realfft: bool = True, mmax=None):
        """
        Returns:
            rings_used: 1D np.ndarray of the ring indices present
            ft: backend tensor of shape [..., nrings_used, mmax+1] (complex),
                where the last axis is m and the ring axis matches the
                rings_used order.
        """
        nside = int(nside)
        cell_ids = np.asarray(cell_ids, dtype=np.int64)
        if mmax is None:
            mmax = min(self.lmax, 3 * nside - 1)
        mmax = int(mmax)

        # Ensure the parent caches for this nside exist (matrix_shift_ph, A/B, ratio_mm, etc.)
        self.shift_ph(nside)

        ring_ids = self._to_ring_ids(nside, cell_ids, nest)
        ring_idx, pos, order, starts, sizes = self._group_by_ring(nside, ring_ids)

        ring_idx = ring_idx[order]
        pos = pos[order]

        i_im = self.backend.bk_cast(im)
        i_im = self.backend.bk_gather(i_im, order, axis=-1)  # reorder the last axis

        rings_used, start_ptr, counts = np.unique(ring_idx, return_index=True, return_counts=True)

        # Build the output per ring as a list, then concatenate
        out_per_ring = []
        for r, s0, cnt in zip(rings_used.tolist(), start_ptr.tolist(), counts.tolist()):
            Nring = int(sizes[r])
            p = pos[s0:s0 + cnt]

            v = self.backend.bk_gather(i_im, np.arange(s0, s0 + cnt, dtype=np.int64), axis=-1)

            if cnt == Nring:
                # Full ring: exact same FFT+tiling logic as alm.comp_tf for one ring.
                # The data must be ordered by pos (already grouped, but make sure pos is 0..N-1)
                if not np.all(p == np.arange(Nring, dtype=p.dtype)):
                    # reorder within the ring
                    sub_order = np.argsort(p)
                    v = self.backend.bk_gather(v, sub_order, axis=-1)

                if realfft:
                    tmp = self.rfft2fft(v)
                else:
                    tmp = self.backend.bk_fft(v)

                l_n = tmp.shape[-1]
                if l_n < mmax + 1:
                    repeat_n = (mmax // l_n) + 1
                    tmp = self.backend.bk_tile(tmp, repeat_n, axis=-1)

                tmp = tmp[..., :mmax + 1]

                # Apply the per-ring shift exp(-i m phi0) exactly like alm.comp_tf
                shift = self.matrix_shift_ph[nside][r, :mmax + 1]  # [m]
                tmp = tmp * shift
                out_per_ring.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [..., 1, m]
            else:
                # Partial ring: partial DFT for the required m, with the same aliasing as the FFT branch
                m_vec = np.arange(mmax + 1, dtype=np.int64)
                m_mod = (m_vec % Nring).astype(np.int64)

                # angles: 2*pi * pos * m_mod / Nring
                ang = (2.0 * np.pi / Nring) * p.astype(np.float64)[:, None] * m_mod[None, :].astype(np.float64)
                ker = np.exp(-1j * ang).astype(np.complex128)  # [cnt, m]

                ker_bk = self.backend.bk_cast(ker)

                # v is [..., cnt]; we want [..., m] = sum over cnt of v * ker
                tmp = self.backend.bk_reduce_sum(
                    self.backend.bk_expand_dims(v, axis=-1) * ker_bk,
                    axis=-2,
                )  # [..., m]

                shift = self.matrix_shift_ph[nside][r, :mmax + 1]  # [m] true m shift
                tmp = tmp * shift
                out_per_ring.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [..., 1, m]

        ft = self.backend.bk_concat(out_per_ring, axis=-2)  # [..., nrings, m]
        return np.asarray(rings_used, dtype=np.int32), ft
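
    # Tiling example: if a short polar ring yields a spectrum of l_n = 4 bins
    # and mmax = 6, it is tiled twice (repeat_n = 6 // 4 + 1 = 2) and truncated
    # to 7 entries, so m = 5 reads bin 5 % 4 = 1, consistent with the aliasing
    # identity noted above.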

    # ---------------------------- map -> alm --------------------------------
    def map2alm_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
        nside = int(nside)
        if lmax is None:
            lmax = min(self.lmax, 3 * nside - 1)
        lmax = int(lmax)

        # Ensure a batch dimension, as alm.map2alm expects
        _added_batch = False
        if hasattr(im, 'ndim') and im.ndim == 1:
            im = im[None, :]
            _added_batch = True
        elif (not hasattr(im, 'ndim')) and len(im.shape) == 1:
            im = im[None, :]
            _added_batch = True

        rings_used, ft = self.comp_tf_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, realfft=True, mmax=lmax)

        # cos(theta) on the used rings
        co_th = np.cos(self.ring_th(nside)[rings_used])

        # ft is [..., R, m]
        alm_out = None

        for m in range(lmax + 1):
            # IMPORTANT: reuse alm.compute_legendre_m and its normalization exactly
            plm = self.compute_legendre_m(co_th, m, lmax, nside) / (12 * nside**2)  # [L, R]
            plm_bk = self.backend.bk_cast(plm)

            ft_m = ft[..., :, m]  # [..., R]
            tmp = self.backend.bk_reduce_sum(
                self.backend.bk_expand_dims(ft_m, axis=-2) * plm_bk,
                axis=-1,
            )  # [..., L]
            l_vals = np.arange(m, lmax + 1, dtype=np.float64)
            scale = np.sqrt(2.0 * l_vals + 1.0)

            # convert scale to a backend tensor (torch) on the right device
            scale_t = self.backend.bk_cast(scale)  # or an equivalent helper
            # reshape for broadcasting if needed: [1, L] or [L]
            shape = (1,) * (tmp.ndim - 1) + (scale_t.shape[0],)
            scale_t = scale_t.reshape(shape)

            tmp = tmp * scale_t
            if m == 0:
                alm_out = tmp
            else:
                alm_out = self.backend.bk_concat([alm_out, tmp], axis=-1)
        if _added_batch:
            alm_out = alm_out[0]
        return alm_out
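
    # Usage sketch (illustrative; assumes a foscat backend instance `bk` and a
    # ring-ordered HEALPix map `imap` of shape (12*16**2,)):
    #   op = alm_loc(backend=bk, lmax=47)
    #   ids = np.arange(12 * 16**2)                       # full sky at nside=16
    #   a_full = op.map2alm_loc(imap, nside=16, cell_ids=ids)
    #   a_part = op.map2alm_loc(imap[:1000], nside=16, cell_ids=ids[:1000])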

    # ---------------------------- alm -> Cl ---------------------------------
    def anafast_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
        if lmax is None:
            lmax = min(self.lmax, 3 * nside - 1)
        lmax = int(lmax)

        alm = self.map2alm_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, lmax=lmax)

        # cl has the same batch dims as alm, plus an ell dim
        batch_shape = alm.shape[:-1]
        cl = torch.zeros(batch_shape + (lmax + 1,), dtype=torch.float64, device=alm.device)

        idx = 0
        for m in range(lmax + 1):
            L = lmax - m + 1
            a = alm[..., idx:idx + L]  # shape: batch + (L,)
            idx += L

            p = self.backend.bk_real(a * self.backend.bk_conjugate(a))  # batch + (L,)

            if m == 0:
                cl[..., m:] += p
            else:
                cl[..., m:] += 2.0 * p

        # divide by (2l+1), broadcasting over the batch dims
        denom = 2 * torch.arange(lmax + 1, dtype=cl.dtype, device=alm.device) + 1  # (lmax+1,)
        denom = denom.reshape((1,) * len(batch_shape) + (lmax + 1,))  # batch-broadcast
        cl = cl / denom
        return cl
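
    # The loop above implements the real-field pseudo-Cl estimator
    #   Cl = (|a_{l,0}|^2 + 2 * sum_{m=1..l} |a_{l,m}|^2) / (2l + 1),
    # where the factor 2 accounts for the conjugate-symmetric m < 0 terms.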
    # Unbatched variant, kept disabled in a string literal:
    '''
    def anafast_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
        if lmax is None:
            lmax = min(self.lmax, 3 * nside - 1)
        lmax = int(lmax)

        alm = self.map2alm_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, lmax=lmax)

        # Unpack and compute Cl with the correct real-field folding:
        cl = torch.zeros((lmax + 1,), dtype=alm.dtype, device=alm.device)

        idx = 0
        for m in range(lmax + 1):
            L = lmax - m + 1
            a = alm[..., idx:idx + L]
            idx += L
            p = self.backend.bk_real(a * self.backend.bk_conjugate(a))
            # sum over any batch dims
            p = self.backend.bk_reduce_sum(p, axis=tuple(range(p.ndim - 1))) if p.ndim > 1 else p
            if m == 0:
                cl[m:] += p
            else:
                cl[m:] += 2.0 * p
        denom = 2 * torch.arange(lmax + 1, dtype=p.dtype, device=alm.device) + 1
        cl = cl / denom
        return cl
    '''
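
# Usage sketch for the power spectrum (illustrative; `bk` is a foscat backend
# and `imap_patch` holds the map values on the selected pixels):
#   op = alm_loc(backend=bk, lmax=47)
#   ids = hp.query_disc(16, (0.0, 0.0, 1.0), 0.3)            # pixels of a small disc
#   cl = op.anafast_loc(imap_patch, nside=16, cell_ids=ids)  # pseudo-Cl, shape (lmax+1,)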