foscat 2025.11.1-py3-none-any.whl → 2026.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/alm_loc.py ADDED
@@ -0,0 +1,270 @@
+import numpy as np
+import healpy as hp
+
+from foscat.alm import alm as _alm
+import torch
+
+
+class alm_loc(_alm):
+    """
+    Local/partial-sky variant of foscat.alm.alm.
+
+    Key design choice (to match alm.py exactly when the full sky is provided):
+    - Reuse *all* Legendre/normalization machinery from the parent class (alm),
+      i.e. shift_ph(), compute_legendre_m(), ratio_mm, the A/B recurrences, etc.
+      This is critical for matching alm.map2alm() numerically.
+
+    Differences vs alm.map2alm():
+    - The input map is [..., n] with explicit (nside, cell_ids).
+    - Only rings touched by cell_ids are processed.
+    - For rings with full coverage, we run the exact same FFT+tiling logic as
+      alm.comp_tf() (but only for those rings) -> bitwise comparable up to
+      backend FFT differences.
+    - For rings with partial coverage, we compute a *partial DFT* for
+      m = 0..mmax, using the same phase convention as alm.comp_tf():
+      the FFT kernel uses exp(-i 2pi (m mod Nring) j / Nring),
+      then the per-ring shift exp(-i m phi0) is applied via
+      self.matrix_shift_ph.
+    """
+
+    def __init__(self, backend=None, lmax=24, limit_range=1e10):
+        super().__init__(backend=backend, lmax=lmax, nside=None, limit_range=limit_range)
+
+    # --------- helpers: ring layout identical to alm.ring_th/ring_ph ----------
+    @staticmethod
+    def _ring_starts_sizes(nside: int):
+        starts = []
+        sizes = []
+        n = 0
+        for k in range(nside - 1):
+            N = 4 * (k + 1)
+            starts.append(n); sizes.append(N)
+            n += N
+        for _ in range(2 * nside + 1):
+            N = 4 * nside
+            starts.append(n); sizes.append(N)
+            n += N
+        for k in range(nside - 1):
+            N = 4 * (nside - 1 - k)
+            starts.append(n); sizes.append(N)
+            n += N
+        return np.asarray(starts, np.int64), np.asarray(sizes, np.int32)
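+
+    # Worked example (nside = 2): the north polar cap contributes one ring of
+    # 4 pixels, the equatorial belt 2*nside+1 = 5 rings of 8, and the south
+    # cap one ring of 4, so starts = [0, 4, 12, 20, 28, 36, 44] and
+    # sizes = [4, 8, 8, 8, 8, 8, 4], totalling 12 * nside**2 = 48 pixels.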
+
+    def _to_ring_ids(self, nside: int, cell_ids: np.ndarray, nest: bool) -> np.ndarray:
+        if nest:
+            return hp.nest2ring(nside, cell_ids)
+        return cell_ids
+
+    def _group_by_ring(self, nside: int, ring_ids: np.ndarray):
+        """
+        Returns:
+            ring_idx: ring number (0..4*nside-2) per pixel
+            pos: position along the ring (0..Nring-1) per pixel
+            order: sort order grouping by ring, then by pos
+            starts, sizes: ring layout
+        """
+        starts, sizes = self._ring_starts_sizes(nside)
+
+        # ring index = last start <= ring_id
+        ring_idx = np.searchsorted(starts, ring_ids, side="right") - 1
+        ring_idx = ring_idx.astype(np.int32)
+
+        pos = (ring_ids - starts[ring_idx]).astype(np.int32)
+
+        order = np.lexsort((pos, ring_idx))
+        return ring_idx, pos, order, starts, sizes
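+
+    # Continuing the nside = 2 example above: ring_ids = [5, 45] gives
+    # ring_idx = [1, 6] (last start <= ring_id, via searchsorted) and
+    # pos = [1, 1] (offset from that ring's start).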
+
+    # ------------------ local Fourier transform per ring ---------------------
+    def comp_tf_loc(self, im, nside: int, cell_ids, nest: bool = False, realfft: bool = True, mmax=None):
+        """
+        Returns:
+            rings_used: 1D np.ndarray of the ring indices present
+            ft: backend tensor of shape [..., nrings_used, mmax+1] (complex),
+                where the last axis is m and the ring axis follows the
+                rings_used order.
+        """
+        nside = int(nside)
+        cell_ids = np.asarray(cell_ids, dtype=np.int64)
+        if mmax is None:
+            mmax = min(self.lmax, 3 * nside - 1)
+        mmax = int(mmax)
+
+        # Ensure the parent caches for this nside exist (matrix_shift_ph, A/B, ratio_mm, etc.)
+        self.shift_ph(nside)
+
+        ring_ids = self._to_ring_ids(nside, cell_ids, nest)
+        ring_idx, pos, order, starts, sizes = self._group_by_ring(nside, ring_ids)
+
+        ring_idx = ring_idx[order]
+        pos = pos[order]
+
+        i_im = self.backend.bk_cast(im)
+        i_im = self.backend.bk_gather(i_im, order, axis=-1)  # reorder the last axis
+
+        rings_used, start_ptr, counts = np.unique(ring_idx, return_index=True, return_counts=True)
+
+        # Build the output per ring as a list, then concatenate
+        out_per_ring = []
+        for r, s0, cnt in zip(rings_used.tolist(), start_ptr.tolist(), counts.tolist()):
+            Nring = int(sizes[r])
+            p = pos[s0:s0 + cnt]
+
+            v = self.backend.bk_gather(i_im, np.arange(s0, s0 + cnt, dtype=np.int64), axis=-1)
+
+            if cnt == Nring:
+                # Full ring: exact same FFT+tiling logic as alm.comp_tf for one ring.
+                # The data must be ordered by pos (already grouped, but make sure pos is 0..N-1).
+                if not np.all(p == np.arange(Nring, dtype=p.dtype)):
+                    # reorder within the ring
+                    sub_order = np.argsort(p)
+                    v = self.backend.bk_gather(v, sub_order, axis=-1)
+
+                if realfft:
+                    tmp = self.rfft2fft(v)
+                else:
+                    tmp = self.backend.bk_fft(v)
+
+                l_n = tmp.shape[-1]
+                if l_n < mmax + 1:
+                    repeat_n = (mmax // l_n) + 1
+                    tmp = self.backend.bk_tile(tmp, repeat_n, axis=-1)
+
+                tmp = tmp[..., :mmax + 1]
+
+                # Apply the per-ring shift exp(-i m phi0) exactly like alm.comp_tf
+                shift = self.matrix_shift_ph[nside][r, :mmax + 1]  # [m]
+                tmp = tmp * shift
+                out_per_ring.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [..., 1, m]
+            else:
+                # Partial ring: partial DFT for the required m, with the same aliasing as the FFT branch
+                m_vec = np.arange(mmax + 1, dtype=np.int64)
+                m_mod = (m_vec % Nring).astype(np.int64)
+
+                # angles: 2pi * pos * m_mod / Nring
+                ang = (2.0 * np.pi / Nring) * p.astype(np.float64)[:, None] * m_mod[None, :].astype(np.float64)
+                ker = np.exp(-1j * ang).astype(np.complex128)  # [cnt, m]
+
+                ker_bk = self.backend.bk_cast(ker)
+
+                # v is [..., cnt]; we want [..., m] = sum over cnt of v * ker
+                tmp = self.backend.bk_reduce_sum(
+                    self.backend.bk_expand_dims(v, axis=-1) * ker_bk,
+                    axis=-2,
+                )  # [..., m]
+
+                shift = self.matrix_shift_ph[nside][r, :mmax + 1]  # [m] true m shift
+                tmp = tmp * shift
+                out_per_ring.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [..., 1, m]
+
+        ft = self.backend.bk_concat(out_per_ring, axis=-2)  # [..., nrings, m]
+        return np.asarray(rings_used, dtype=np.int32), ft
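+
+    # Sanity-check sketch for the partial-DFT branch (pure numpy, made-up
+    # values): on a fully sampled ring the kernel above reduces to the FFT,
+    #
+    #     N = 8; v = np.random.rand(N); m = np.arange(N)
+    #     ker = np.exp(-2j * np.pi * np.outer(np.arange(N), m % N) / N)
+    #     assert np.allclose(v @ ker, np.fft.fft(v))
+    #
+    # so full and partial rings share one phase convention by construction.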
+
+    # ---------------------------- map -> alm --------------------------------
+    def map2alm_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
+        nside = int(nside)
+        if lmax is None:
+            lmax = min(self.lmax, 3 * nside - 1)
+        lmax = int(lmax)
+
+        # Ensure a batch dimension, like alm.map2alm expects
+        _added_batch = False
+        if hasattr(im, 'ndim') and im.ndim == 1:
+            im = im[None, :]
+            _added_batch = True
+        elif (not hasattr(im, 'ndim')) and len(im.shape) == 1:
+            im = im[None, :]
+            _added_batch = True
+
+        rings_used, ft = self.comp_tf_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, realfft=True, mmax=lmax)
+
+        # cos(theta) on the used rings
+        co_th = np.cos(self.ring_th(nside)[rings_used])
+
+        # ft is [..., R, m]
+        alm_out = None
+
+        for m in range(lmax + 1):
+            # IMPORTANT: reuse alm.compute_legendre_m and its normalization exactly
+            plm = self.compute_legendre_m(co_th, m, lmax, nside) / (12 * nside**2)  # [L, R]
+            plm_bk = self.backend.bk_cast(plm)
+
+            ft_m = ft[..., :, m]  # [..., R]
+            tmp = self.backend.bk_reduce_sum(
+                self.backend.bk_expand_dims(ft_m, axis=-2) * plm_bk,
+                axis=-1,
+            )  # [..., L]
+            l_vals = np.arange(m, lmax + 1, dtype=np.float64)
+            scale = np.sqrt(2.0 * l_vals + 1.0)
+
+            # convert scale to a backend tensor (torch) on the right device
+            scale_t = self.backend.bk_cast(scale)  # or an equivalent helper
+            # reshape for broadcasting if needed: [1, L] or [L]
+            shape = (1,) * (tmp.ndim - 1) + (scale_t.shape[0],)
+            scale_t = scale_t.reshape(shape)
+
+            tmp = tmp * scale_t
+            if m == 0:
+                alm_out = tmp
+            else:
+                alm_out = self.backend.bk_concat([alm_out, tmp], axis=-1)
+        if _added_batch:
+            alm_out = alm_out[0]
+        return alm_out
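+
+    # Full-sky sanity check (sketch, not part of the API): with cell_ids
+    # covering every pixel, map2alm_loc is meant to reproduce alm.map2alm,
+    # roughly
+    #
+    #     cell_ids = np.arange(12 * nside**2)
+    #     np.allclose(a.map2alm_loc(m, nside, cell_ids), a.map2alm(m))
+    #
+    # This assumes a full-sky RING-ordered map `m`; see foscat.alm for
+    # alm.map2alm's exact signature.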
+
+    # ---------------------------- alm -> Cl ---------------------------------
+    def anafast_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
+        if lmax is None:
+            lmax = min(self.lmax, 3 * nside - 1)
+        lmax = int(lmax)
+
+        alm = self.map2alm_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, lmax=lmax)
+
+        # cl has the same batch dims as alm, plus an ell dim
+        batch_shape = alm.shape[:-1]
+        cl = torch.zeros(batch_shape + (lmax + 1,), dtype=torch.float64, device=alm.device)
+
+        idx = 0
+        for m in range(lmax + 1):
+            L = lmax - m + 1
+            a = alm[..., idx:idx + L]  # shape: batch + (L,)
+            idx += L
+
+            p = self.backend.bk_real(a * self.backend.bk_conjugate(a))  # batch + (L,)
+
+            if m == 0:
+                cl[..., m:] += p
+            else:
+                cl[..., m:] += 2.0 * p
+
+        # divide by (2l+1), broadcast over the batch dims
+        denom = 2 * torch.arange(lmax + 1, dtype=cl.dtype, device=alm.device) + 1  # (lmax+1,)
+        denom = denom.reshape((1,) * len(batch_shape) + (lmax + 1,))  # batch-broadcast
+        cl = cl / denom
+        return cl
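+
+    # Real-field folding used above (explanatory note): for a real map,
+    # a_{l,-m} = (-1)^m * conj(a_{l,m}), so
+    #     C_l = ( |a_{l,0}|^2 + 2 * sum_{m=1..l} |a_{l,m}|^2 ) / (2l + 1),
+    # which is exactly the m == 0 / m > 0 branch plus the (2l+1) division.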
+
+    '''
+    def anafast_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
+        if lmax is None:
+            lmax = min(self.lmax, 3 * nside - 1)
+        lmax = int(lmax)
+
+        alm = self.map2alm_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, lmax=lmax)
+
+        # Unpack and compute Cl with the correct real-field folding:
+        cl = torch.zeros((lmax + 1,), dtype=alm.dtype, device=alm.device)
+
+        idx = 0
+        for m in range(lmax + 1):
+            L = lmax - m + 1
+            a = alm[..., idx:idx + L]
+            idx += L
+            p = self.backend.bk_real(a * self.backend.bk_conjugate(a))
+            # sum over any batch dims
+            p = self.backend.bk_reduce_sum(p, axis=tuple(range(p.ndim - 1))) if p.ndim > 1 else p
+            if m == 0:
+                cl[m:] += p
+            else:
+                cl[m:] += 2.0 * p
+        denom = 2 * torch.arange(lmax + 1, dtype=p.dtype, device=alm.device) + 1
+        cl = cl / denom
+        return cl
+    '''