foscat 2025.7.2__py3-none-any.whl → 2025.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +34 -3
- foscat/CNN.py +1 -0
- foscat/FoCUS.py +387 -165
- foscat/HOrientedConvol.py +546 -0
- foscat/HealSpline.py +8 -5
- foscat/Synthesis.py +27 -18
- foscat/UNET.py +200 -0
- foscat/scat_cov.py +289 -178
- foscat/scat_cov_map2D.py +1 -1
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/METADATA +1 -1
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/RECORD +14 -12
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/WHEEL +0 -0
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,546 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import healpy as hp
|
|
4
|
+
from scipy.sparse import csr_array
|
|
5
|
+
import torch
|
|
6
|
+
from scipy.spatial import cKDTree
|
|
7
|
+
|
|
8
|
+
class HOrientedConvol:
    """
    Oriented KERNELSZ x KERNELSZ convolutions on (full-sky or partial) HEALPix maps.

    For every processed pixel the ``KERNELSZ**2`` nearest pixel centres are
    found with a KD-tree and rotated so that the central pixel sits on the
    North pole. The rotated neighbour coordinates are then used either to
    build oriented complex wavelets (:meth:`make_wavelet_matrix`) or to
    bilinearly sample a KERNELSZ x KERNELSZ kernel grid
    (:meth:`make_idx_weights` followed by :meth:`Convol_torch`).

    Attributes (set by ``__init__``)
    --------------------------------
    cell_ids : ndarray, shape (M,)
        HEALPix ids of the processed pixels (ordering given by ``nest``).
    idx_nn : ndarray, shape (M, P)
        Local (0..M-1) neighbour indices, nearest first; column 0 is normally
        the pixel itself. ``P = min(KERNELSZ**2, M)``. Replaced by a torch
        LongTensor by :meth:`make_idx_weights`.
    vec_rot : ndarray, shape (M, P, 3)
        Neighbour unit vectors rotated so that their centre is at (0, 0, 1).
    t, p : ndarray, shape (M,)
        Colatitude and longitude of each centre pixel.
    nside, KERNELSZ : int
    local_test : bool
        True when explicit ``cell_ids`` were given.
    """

    def __init__(self, nside, KERNELSZ, cell_ids=None, nest=True):
        """
        Parameters
        ----------
        nside : int
            HEALPix resolution.
        KERNELSZ : int
            Odd kernel width; ``KERNELSZ**2`` neighbours are gathered per pixel.
        cell_ids : array-like or torch.Tensor, optional
            HEALPix ids of the pixels to process. ``None`` means the full sky
            (``np.arange(12 * nside**2)``).
        nest : bool
            Ordering of ``cell_ids``: True for NESTED, False for RING.

        Raises
        ------
        ValueError
            If ``KERNELSZ`` is even.
        """
        if KERNELSZ % 2 == 0:
            raise ValueError(f"N must be odd so that coordinates are integers from -K..K; got N={KERNELSZ}.")

        if cell_ids is None:
            self.cell_ids = np.arange(12 * nside**2)
            self.local_test = False
        else:
            # Accepts NumPy arrays, lists and torch tensors on any device.
            self.cell_ids = self._to_numpy(cell_ids)
            self.local_test = True

        # Local (0..M-1) indices of the KERNELSZ**2 nearest pixel centres.
        idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
                                          KERNELSZ * KERNELSZ,
                                          nside,
                                          nest=nest,
                                          )

        # (3, 3, M): R[:, :, i] maps the North pole onto the centre of pixel i.
        mat_pt = self.rotation_matrices_from_healpix(nside, self.cell_ids, nest=nest)

        # idx_nn is local, so translate through cell_ids (identity for the
        # full sky). Use the caller's ordering: the rotation matrices above
        # were computed with the same `nest`, so both must agree.
        t, p = hp.pix2ang(nside, self.cell_ids[idx_nn], nest=nest)

        vec_orig = hp.ang2vec(t, p)

        # Apply R^T to every neighbour so that its centre lands on (0, 0, 1).
        # The subscripts read vec_orig as (P, M, 3): hp.ang2vec returns the
        # transpose of its stacked (3, M, P) components. Output: (M, P, 3).
        self.vec_rot = np.einsum('mki,ijk->kmj', vec_orig, mat_pt)

        del mat_pt
        del vec_orig
        # Column 0 of the neighbour table is the centre pixel itself.
        self.t = t[:, 0]
        self.p = p[:, 0]
        self.idx_nn = idx_nn
        self.nside = nside
        self.KERNELSZ = KERNELSZ

    @staticmethod
    def _to_numpy(a, dtype=None):
        """Return ``a`` as a NumPy array; torch tensors (any device) are copied to host first."""
        if isinstance(a, torch.Tensor):
            a = a.detach().cpu().numpy()
        return np.asarray(a, dtype=dtype)

    def remap_by_first_column(self, idx: np.ndarray) -> np.ndarray:
        """
        Remap the values in `idx` so that:
        - The first column becomes [0, 1, ..., N-1]
        - All other columns are updated accordingly using the same mapping.

        Parameters
        ----------
        idx : np.ndarray
            Integer array of shape (N, m). Values that do not occur in the
            first column are mapped to -1. If a value occurs several times in
            the first column, the last matching row index is used.

        Returns
        -------
        np.ndarray
            New integer array of shape (N, m) with remapped indices.

        Raises
        ------
        ValueError
            If ``idx`` is not 2-D.
        """
        if idx.ndim != 2:
            raise ValueError("idx must be a 2D array of shape (N, m)")

        keys = idx[:, 0]
        # Sort the keys once and binary-search every entry (vectorised)
        # instead of a per-element Python dict lookup through np.vectorize.
        order = np.argsort(keys, kind="stable")
        sorted_keys = keys[order]
        # side="right" - 1 lands on the LAST occurrence of a duplicated key,
        # i.e. the row a {key: row} dict built front-to-back would keep.
        pos = np.searchsorted(sorted_keys, idx, side="right") - 1
        safe = np.clip(pos, 0, None)
        found = (pos >= 0) & (sorted_keys[safe] == idx)
        return np.where(found, order[safe], -1).astype(int)

    def rotation_matrices_from_healpix(self, nside, hpix_idx, nest=True):
        """
        Rotation matrices relating each HEALPix pixel centre to the North pole.

        ``R[:, :, n] = Rz(phi_n) @ Ry_n`` where ``Rz`` rotates about z by the
        pixel longitude and ``Ry_n`` rotates about y by the pixel colatitude.
        ``R[:, :, n]`` maps the North pole (0, 0, 1) onto the centre of pixel
        ``hpix_idx[n]``. Its transpose therefore moves that centre to the
        North pole.

        Parameters
        ----------
        nside : int
            Healpix Nside resolution.
        hpix_idx : array_like or torch.Tensor, shape (N,)
            Healpix pixel indices.
        nest : bool, optional
            True if indices are in NESTED ordering, False for RING ordering.

        Returns
        -------
        R : ndarray, shape (3, 3, N)
            Rotation matrices for each pixel index.
        """
        hpix_idx = self._to_numpy(hpix_idx)

        N = hpix_idx.shape[0]

        # theta: colatitude (0 = North pole), phi: longitude
        theta, phi = hp.pix2ang(nside, hpix_idx, nest=nest)

        cphi = np.cos(phi)
        sphi = np.sin(phi)
        cthi = np.cos(-theta)
        sthi = np.sin(-theta)

        # Rotation around Z (by phi)
        Rz = np.zeros((3, 3, N))
        Rz[0, 0, :] = cphi
        Rz[0, 1, :] = -sphi
        Rz[1, 0, :] = sphi
        Rz[1, 1, :] = cphi
        Rz[2, 2, :] = 1.0

        # Rotation around Y: sends (0, 0, 1) to (sin(theta), 0, cos(theta))
        Ry = np.zeros((3, 3, N))
        Ry[0, 0, :] = cthi
        Ry[0, 2, :] = -sthi
        Ry[1, 1, :] = 1.0
        Ry[2, 0, :] = sthi
        Ry[2, 2, :] = cthi

        # Per-pixel matrix product Rz @ Ry
        R = np.einsum('ijk,jlk->ilk', Rz, Ry)

        return R

    def _choose_depth_for_candidates(self, N, overshoot=2, max_depth=12):
        """
        Pick hierarchy depth d so that ~ 9 * 4**d >= overshoot * N.
        Depth 0 => 9 candidates; 1 => 36; 2 => 144; 3 => 576; 4 => 2304; etc.

        Not referenced by the other methods of this class.
        """
        d = 0
        while 9 * (4 ** d) < overshoot * N and d < max_depth:
            d += 1
        return d

    def knn_healpix_ckdtree(self,
                            hidx, N, nside, *, nest=True,
                            include_self=True,
                            vec_dtype=np.float32,
                            out_dtype=np.int64
                            ):
        """
        k-NN using a cKDTree on unit vectors (exact in Euclidean space).

        Parameters
        ----------
        hidx : array-like or torch.Tensor, shape (M,)
            HEALPix pixel ids (ordering given by ``nest``).
        N : int
            Number of neighbours requested per pixel (>= 1). It is capped at
            the number of available pixels.
        nside : int
            HEALPix resolution.
        nest : bool
            Ordering of ``hidx``.
        include_self : bool
            If True, each pixel is its own first neighbour (distance 0).
            Otherwise it is excluded. A lone pixel (M == 1) still gets itself,
            since no other candidate exists.
        vec_dtype, out_dtype : dtype
            Precision of the unit vectors and dtype of the returned indices.

        Returns
        -------
        ndarray, shape (M, N_eff)
            LOCAL indices (0..M-1) of the nearest neighbours, nearest first.
            Returns ``(0, 0)`` for empty input.

        Raises
        ------
        ValueError
            If ``hidx`` is not 1-D or ``N < 1``.
        """
        hidx = self._to_numpy(hidx, dtype=np.int64)

        if hidx.ndim != 1:
            raise ValueError("hidx must be 1D")
        M = hidx.size
        if M == 0:
            return np.empty((0, 0), dtype=out_dtype)
        if N <= 0:
            raise ValueError("N must be >= 1")

        # Effective N
        N_eff = min(N, M if include_self else max(M - 1, 1))

        # Build unit vectors
        hidx_n = hidx if nest else hp.ring2nest(nside, hidx)
        x, y, z = hp.pix2vec(nside, hidx_n, nest=True)
        V = np.stack([x, y, z], axis=1).astype(vec_dtype, copy=False)  # (M,3)

        tree = cKDTree(V)

        if include_self:
            # Self appears with distance 0 as the first neighbour.
            _, idx = tree.query(V, k=N_eff, workers=-1)
            # cKDTree squeezes the last axis when k == 1; restore (M, N_eff).
            return np.reshape(idx, (M, N_eff)).astype(out_dtype, copy=False)

        if M == 1:
            # No other pixel exists; fall back to the pixel itself.
            return np.zeros((1, N_eff), dtype=out_dtype)

        # Ask for one extra neighbour (N_eff + 1 <= M here) and drop self. If
        # self was crowded out by coincident points, drop the farthest one.
        _, idx = tree.query(V, k=N_eff + 1, workers=-1)
        idx = np.reshape(idx, (M, N_eff + 1))
        drop = idx == np.arange(M)[:, None]
        drop[~drop.any(axis=1), -1] = True
        # Exactly one entry per row is dropped; row-wise order is preserved.
        return idx[~drop].reshape(M, N_eff).astype(out_dtype, copy=False)

    def make_wavelet_matrix(self,
                            orientations,
                            polar=True,
                            norm_mean=True,
                            norm_std=True,
                            return_index=False,
                            return_smooth=False,
                            ):
        """
        Build oriented complex Gabor-like wavelets on each pixel neighbourhood.

        For orientation ``o``, centre ``n`` and neighbour ``k`` the weight is
        ``exp(-sigma_gauss * r * nside**2) * (cos(a) - 1j * sin(a))`` with
        ``a = pi * sigma_cosine * nside * x``. Here ``r`` is the squared chord
        distance of the rotated neighbour to (0, 0, 1), and ``x`` is its
        rotated (x, y) offset projected on a direction set by the orientation.

        Parameters
        ----------
        orientations : array-like, shape (NORIENT,)
            Orientation angles in radians.
        polar : bool
            If True, the projection direction is also rotated by the centre
            longitude, with a sign flip of the y component for centres south
            of the equator.
        norm_mean : bool
            Subtract, per (orientation, centre), the mean of the real and the
            imaginary parts.
        norm_std : bool
            Scale each (orientation, centre) wavelet by ``1 / sum(|w|)``, which
            is computed before mean removal. Also normalise the smoothing
            kernel to unit sum.
        return_index : bool
            Return flat weights and ``(row, col)`` index pairs instead of a
            sparse matrix.
        return_smooth : bool
            With ``return_index``, also return the Gaussian smoothing kernel
            and its index pairs.

        Returns
        -------
        csr_array, shape (Npix, Npix * NORIENT)
            When ``return_index`` is False. Entry
            ``W[neighbour, centre + o * Npix]`` holds the weight, where
            ``Npix`` is the number of processed pixels.
        (w, idx) or (w, idx, wsmooth, idx_smooth)
            When ``return_index`` is True. ``w`` is flat with length
            ``NORIENT * Npix * P``. ``idx`` has shape ``(len(w), 2)`` and holds
            (neighbour, centre + o * Npix). The smoothing pair uses
            (neighbour, centre).
        """
        sigma_gauss = 0.5
        sigma_cosine = 0.5
        if self.KERNELSZ == 3:
            sigma_gauss = 1.0 / np.sqrt(2)
            sigma_cosine = 1.0

        orientations = np.asarray(orientations)
        NORIENT = orientations.shape[0]

        # make_idx_weights may already have turned idx_nn into a torch tensor.
        idx_nn = self._to_numpy(self.idx_nn)
        Npix, NK = idx_nn.shape

        vx = self.vec_rot[None, :, :, 0]
        vy = self.vec_rot[None, :, :, 1]
        vz = self.vec_rot[None, :, :, 2]

        if polar:
            # +1 for centres in the northern hemisphere, -1 otherwise.
            rotate = 2 * ((self.t < np.pi / 2) - 0.5)[None, :, None]
            ang = (self.p[None, :] + np.pi / 2 - orientations[:, None])[:, :, None]
            xx = np.cos(ang) * vx - rotate * np.sin(ang) * vy
        else:
            ang = np.pi / 2 - orientations[:, None, None]
            xx = np.cos(ang) * vx - np.sin(ang) * vy

        # Squared distance to the North pole, where every centre was rotated.
        r = (vx**2 + vy**2 + (vz - 1.0)**2)
        gauss = np.exp(-sigma_gauss * r * self.nside**2)  # (1, Npix, NK)

        if return_smooth:
            wsmooth = gauss
            if norm_std:
                wsmooth = wsmooth / np.sum(wsmooth, 2)[:, :, None]

        # for consistency with previous definition
        phase = xx * self.nside * sigma_cosine * np.pi
        w = gauss * (np.cos(phase) - 1J * np.sin(phase))  # (NORIENT, Npix, NK)

        if norm_std:
            ww = 1 / np.sum(np.abs(w), 2)[:, :, None]
        else:
            ww = 1.0

        if norm_mean:
            w = w.real - np.mean(w.real, 2)[:, :, None] + 1J * (w.imag - np.mean(w.imag, 2)[:, :, None])
        # Apply the amplitude normalisation whether or not the mean was removed.
        w = w * ww

        # Rows: neighbour pixel; columns: centre pixel + orientation * Npix.
        indice_1_0 = np.tile(idx_nn.flatten(), NORIENT)
        indice_1_1 = np.tile(np.repeat(idx_nn[:, 0], NK), NORIENT) + \
            np.repeat(np.arange(NORIENT), Npix * NK) * Npix
        w = w.flatten()

        if return_smooth:
            indice_2_0 = idx_nn.flatten()
            indice_2_1 = np.repeat(idx_nn[:, 0], NK)
            wsmooth = wsmooth.flatten()

        if return_index:
            if return_smooth:
                return w, np.concatenate([indice_1_0[:, None], indice_1_1[:, None]], 1), wsmooth, np.concatenate([indice_2_0[:, None], indice_2_1[:, None]], 1)

            return w, np.concatenate([indice_1_0[:, None], indice_1_1[:, None]], 1)

        # Npix == 12 * nside**2 for the full sky; for partial maps the column
        # layout above uses a stride of Npix, so the matrix must match it.
        return csr_array((w, (indice_1_0, indice_1_1)), shape=(Npix, Npix * NORIENT))

    def make_idx_weights(self, polar=False, gamma=1.0, device='cuda', allow_extrapolation=True):
        """
        Precompute the bilinear kernel-sampling tables used by :meth:`Convol_torch`.

        Each rotated neighbour position is scaled by ``nside * gamma`` and
        located on the KERNELSZ x KERNELSZ kernel grid. It is then described
        by the four surrounding grid nodes and their bilinear weights.

        Parameters
        ----------
        polar : bool
            If True, rotate neighbour offsets by the centre longitude, with a
            sign flip of the y component for centres south of the equator.
            Otherwise use the raw rotated x/y components.
        gamma : float
            Extra scale factor applied to the neighbour coordinates.
        device : str or torch.device
            Device on which the resulting tensors are stored.
        allow_extrapolation : bool
            Forwarded to :meth:`bilinear_weights_NxN`.

        Sets
        ----
        self.idx_nn : LongTensor, shape (Npix, P)
        self.w_idx : LongTensor, shape (Npix, 4, P)
            Flat kernel-grid indices of the four cell corners.
        self.w_w : float64 Tensor, shape (Npix, 4, P)
            Matching bilinear weights.
        """
        rotate = 2 * ((self.t < np.pi / 2) - 0.5)[:, None]
        if polar:
            xx = np.cos(self.p)[:, None] * self.vec_rot[:, :, 0] - rotate * np.sin(self.p)[:, None] * self.vec_rot[:, :, 1]
            yy = -np.sin(self.p)[:, None] * self.vec_rot[:, :, 0] - rotate * np.cos(self.p)[:, None] * self.vec_rot[:, :, 1]
        else:
            xx = self.vec_rot[:, :, 0]
            yy = self.vec_rot[:, :, 1]

        self.w_idx, self.w_w = self.bilinear_weights_NxN(xx * self.nside * gamma,
                                                         yy * self.nside * gamma,
                                                         allow_extrapolation=allow_extrapolation)

        # torch.as_tensor keeps the requested dtype exactly. torch.Tensor(...)
        # would route through float32, corrupting indices > 2**24 and the
        # float64 weights.
        self.idx_nn = torch.as_tensor(self.idx_nn, dtype=torch.long, device=device)
        self.w_idx = torch.as_tensor(self.w_idx, dtype=torch.long, device=device)
        self.w_w = torch.as_tensor(self.w_w, dtype=torch.float64, device=device)

    def _grid_index(self, xi, yi):
        """
        Map integer grid coords (xi, yi), each in {-K, ..., K} with
        K = KERNELSZ // 2, to the flat index in [0, KERNELSZ**2 - 1].

        The layout is row-major: y varies slowest, from -K to K.
        """
        return (yi + self.KERNELSZ // 2) * self.KERNELSZ + (xi + self.KERNELSZ // 2)

    def bilinear_weights_NxN(self, x, y, allow_extrapolation=True):
        """
        Compute bilinear weights on the N x N kernel grid, with N = self.KERNELSZ.

        Grid nodes sit at integer coordinates (xi, yi) in {-K, ..., +K}^2,
        where K = N // 2. The query point (x, y) is continuous in the same
        coordinate system. For each query the unit cell [x0, x0+1] x [y0, y0+1]
        is selected, and standard bilinear weights are computed relative to
        (x0, y0).

        Parameters
        ----------
        x, y : float or array-like of identical shape S
            Query coordinates in the grid coordinate system.
        allow_extrapolation : bool, default True
            - If False: clamp (x, y) to [-K, +K] so that tx, ty lie in [0, 1]
              and the weights are non-negative and sum to 1.
            - If True: do not clamp (x, y). The nearest in-grid boundary cell
              is still used for the indices, but tx, ty may fall outside
              [0, 1], giving extrapolation with possibly negative weights.

        Returns
        -------
        idx : ndarray, dtype int64
            Flat row-major indices (0 .. N*N-1) of the four corners, in the
            order [(x0,y0), (x0+1,y0), (x0,y0+1), (x0+1,y0+1)].
            The corners are stacked on axis 1: shape (S[0], 4, *S[1:]). For
            1-D input (M,) this is (M, 4); for (Npix, P) input it is
            (Npix, 4, P). Scalars are treated as shape (1,).
        w : ndarray of float64, same shape as ``idx``
            Matching bilinear weights.

        Raises
        ------
        ValueError
            If ``x`` and ``y`` have different shapes.
        """
        N = self.KERNELSZ
        K = N // 2

        x = np.atleast_1d(np.asarray(x, dtype=float))
        y = np.atleast_1d(np.asarray(y, dtype=float))
        if x.shape != y.shape:
            raise ValueError("x and y must have the same shape")

        # Optionally clamp queries (pure interpolation).
        if not allow_extrapolation:
            x = np.clip(x, -K, K)
            y = np.clip(y, -K, K)

        # Cell origin = floor(x), floor(y), kept in [-K, K-1] so +1 stays in-bounds.
        x0 = np.clip(np.floor(x), -K, K - 1).astype(int)
        y0 = np.clip(np.floor(y), -K, K - 1).astype(int)
        x1 = x0 + 1
        y1 = y0 + 1

        # Local coordinates within the cell (unit spacing).
        tx = x - x0
        ty = y - y0

        # (x0,y0) w00, (x1,y0) w10, (x0,y1) w01, (x1,y1) w11
        w00 = (1.0 - tx) * (1.0 - ty)
        w10 = tx * (1.0 - ty)
        w01 = (1.0 - tx) * ty
        w11 = tx * ty
        w = np.stack([w00, w10, w01, w11], axis=1)

        # Flat row-major index: (yi + K) * N + (xi + K)
        def flat_idx(xi, yi):
            return (yi + K) * N + (xi + K)

        idx = np.stack([flat_idx(x0, y0),
                        flat_idx(x1, y0),
                        flat_idx(x0, y1),
                        flat_idx(x1, y1)], axis=1).astype(np.int64)

        return idx, w

    def Convol_torch(self, im, ww):
        """
        Batched KERNELSZ x KERNELSZ neighbourhood aggregation in pure PyTorch.

        out[b, o, n] = sum_{c, s, p} im[b, c, idx_nn[n, p]]
                                     * ww[.., c, o, w_idx[n, s, p], (s)]
                                     * w_w[n, s, p]

        Parameters
        ----------
        im : Tensor, shape (B, C_i, Npix)
            Input features per pixel for a batch of B samples.
        ww : Tensor
            Kernel weights. M = KERNELSZ**2 is the flat kernel-grid axis,
            indexed by self.w_idx. Supported shapes:
                (C_i, C_o, M)
                (C_i, C_o, M, S)
                (B, C_i, C_o, M)
                (B, C_i, C_o, M, S)
            A 4-D ``ww`` is classified by where its size-M axis sits.

        Class members used (NumPy arrays or tensors; aligned to im.device/dtype):
            self.idx_nn : (Npix, P)
                Neighbour indices into the Npix axis of ``im``.
            self.w_idx : (Npix, P) or (Npix, S, P)
                Indices along M. For :meth:`make_idx_weights`, S = 4 bilinear
                corners.
            self.w_w : same shape as w_idx
                Extra scalar weights per entry.

        Returns
        -------
        out : Tensor, shape (B, C_o, Npix)
        """
        # ---- Basic checks ----
        assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"

        B, C_i, Npix = im.shape
        device = im.device
        dtype = im.dtype
        MK = self.KERNELSZ * self.KERNELSZ

        # Align class members to device/dtype (NumPy arrays are accepted too).
        idx_nn = torch.as_tensor(self.idx_nn).to(device=device, dtype=torch.long)  # (Npix, P)
        w_idx = torch.as_tensor(self.w_idx).to(device=device, dtype=torch.long)    # (Npix, P) or (Npix, S, P)
        w_w = torch.as_tensor(self.w_w).to(device=device, dtype=dtype)             # (Npix, P) or (Npix, S, P)

        assert idx_nn.ndim == 2 and idx_nn.size(0) == Npix, \
            f"`idx_nn` must be (Npix, P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
        P = idx_nn.size(1)

        # ---- 1) Gather neighbour values: (B, C_i, Npix, P) ----
        rim = torch.take_along_dim(
            im.unsqueeze(-1),
            idx_nn.unsqueeze(0).unsqueeze(0),
            dim=2
        )

        # ---- 2) Normalise w_idx / w_w to (Npix, S, P) ----
        if w_idx.ndim == 2:
            assert w_idx.size(0) == Npix and w_idx.size(1) == P
            w_idx_eff = w_idx.unsqueeze(1)
            w_w_eff = w_w.unsqueeze(1)
            S = 1
        elif w_idx.ndim == 3:
            Npix_, S, P_ = w_idx.shape
            assert Npix_ == Npix and P_ == P, \
                f"`w_idx` must be (Npix,S,P) with Npix={Npix}, P={P}, got {tuple(w_idx.shape)}"
            assert w_w.shape == w_idx.shape, "`w_w` must match `w_idx` shape"
            w_idx_eff = w_idx
            w_w_eff = w_w
        else:
            raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")

        # ---- 3) Normalise ww to (B, C_i, C_o, M, S) ----
        if ww.ndim == 3:
            C_i_w, C_o, M = ww.shape
            assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
            ww_eff = ww.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, -1, S)

        elif ww.ndim == 4:
            # Decide from the position of the kernel axis (size MK), so that
            # C_i == C_o no longer confuses the two layouts.
            looks_cioms = ww.shape[0] == C_i and ww.shape[2] == MK and ww.shape[3] == S
            looks_bcio = ww.shape[0] == B and ww.shape[1] == C_i and ww.shape[3] == MK
            if looks_cioms:
                # (C_i, C_o, M, S) -> (B, C_i, C_o, M, S)
                _, C_o, M, _ = ww.shape
                ww_eff = ww.unsqueeze(0).expand(B, -1, -1, -1, -1)
            elif looks_bcio:
                # (B, C_i, C_o, M) -> (B, C_i, C_o, M, S)
                _, _, C_o, M = ww.shape
                ww_eff = ww.unsqueeze(-1).expand(-1, -1, -1, -1, S)
            else:
                raise ValueError(
                    f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)"
                )

        elif ww.ndim == 5:
            assert ww.shape[0] == B and ww.shape[1] == C_i, "ww batch/C_i mismatch"
            _, _, _, M, S_w = ww.shape
            assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
            ww_eff = ww
        else:
            raise ValueError(f"Unsupported ww shape {tuple(ww.shape)}")

        # The kernel axis must be KERNELSZ*KERNELSZ whatever the layout.
        assert M == MK, f"`ww` kernel axis must have KERNELSZ*KERNELSZ={MK} entries, got shape {tuple(ww.shape)}"

        # ---- 4) Gather along M -> (B, C_i, C_o, Npix, S, P) ----
        idx_exp = w_idx_eff.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # (1,1,1,Npix,S,P)
        rw = torch.take_along_dim(
            ww_eff.unsqueeze(-1),  # (B, C_i, C_o, M, S, 1)
            idx_exp,
            dim=3
        )

        # ---- 5) Apply extra neighbour weights ----
        rw = rw * w_w_eff.unsqueeze(0).unsqueeze(0).unsqueeze(0)

        # ---- 6) Combine: sum over P, then S, then C_i ----
        rim_exp = rim[:, :, None, :, None, :]  # (B, C_i, 1, Npix, 1, P)
        out_ci = (rim_exp * rw).sum(dim=-1)    # (B, C_i, C_o, Npix, S)
        out_ci = out_ci.sum(dim=-1)            # (B, C_i, C_o, Npix)
        out = out_ci.sum(dim=1)                # (B, C_o, Npix)

        return out
|
foscat/HealSpline.py
CHANGED
|
@@ -51,7 +51,7 @@ class heal_spline:
|
|
|
51
51
|
kind='cubic', fill_value='extrapolate')
|
|
52
52
|
|
|
53
53
|
|
|
54
|
-
def ang2weigths(self,th,ph,threshold=1E-2,nest=
|
|
54
|
+
def ang2weigths(self,th,ph,threshold=1E-2,nest=False):
|
|
55
55
|
th0=self.f_interp_th(th).flatten()
|
|
56
56
|
|
|
57
57
|
idx_lat,w_th=self.spline_lat.eval(th0.flatten())
|
|
@@ -73,11 +73,13 @@ class heal_spline:
|
|
|
73
73
|
www=www.reshape(16,www.shape[2])
|
|
74
74
|
all_idx=all_idx.reshape(16,all_idx.shape[2])
|
|
75
75
|
|
|
76
|
+
if nest:
|
|
77
|
+
all_idx = hp.ring2nest(self.nside,all_idx)
|
|
78
|
+
|
|
76
79
|
heal_idx,inv_idx = np.unique(all_idx,
|
|
77
80
|
return_inverse=True)
|
|
78
81
|
all_idx = inv_idx
|
|
79
|
-
|
|
80
|
-
heal_idx = hp.ring2nest(self.nside,heal_idx)
|
|
82
|
+
|
|
81
83
|
self.cell_ids = heal_idx
|
|
82
84
|
|
|
83
85
|
hit=np.bincount(all_idx.flatten(),weights=www.flatten())
|
|
@@ -190,8 +192,9 @@ class heal_spline:
|
|
|
190
192
|
spl=np.zeros([scale])
|
|
191
193
|
spl[heal_idx[ih==k]-scale*h[k]]=self.spline[ih==k]
|
|
192
194
|
self.spline_tree[h[k]]=spl
|
|
193
|
-
|
|
194
|
-
|
|
195
|
+
|
|
196
|
+
def GetParam(self):
    """Return the ``(heal_idx, spline)`` attribute pair of this spline object.

    NOTE(review): in the visible hunks ``ang2weigths`` stores the unique pixel
    ids as ``self.cell_ids``, not ``self.heal_idx``. Confirm that
    ``self.heal_idx`` is assigned elsewhere before relying on this getter.
    """
    heal_idx = self.heal_idx
    spline = self.spline
    return heal_idx, spline
|
|
195
198
|
|
|
196
199
|
def Transform(self,th,ph,threshold=1E-2,nest=True):
|
|
197
200
|
|
foscat/Synthesis.py
CHANGED
|
@@ -195,6 +195,8 @@ class Synthesis:
|
|
|
195
195
|
x = self.operation.backend.bk_reshape(
|
|
196
196
|
self.operation.backend.bk_cast(in_x), self.oshape
|
|
197
197
|
)
|
|
198
|
+
if self.idx_grd is not None:
|
|
199
|
+
x=x[self.idx_grd]
|
|
198
200
|
|
|
199
201
|
self.l_log[
|
|
200
202
|
self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
|
|
@@ -246,6 +248,11 @@ class Synthesis:
|
|
|
246
248
|
|
|
247
249
|
g_tot[np.isnan(g_tot)] = 0.0
|
|
248
250
|
|
|
251
|
+
if self.idx_grd is not None:
|
|
252
|
+
lg_tot=np.zeros(in_x.shape)
|
|
253
|
+
lg_tot[self.idx_grd]=g_tot
|
|
254
|
+
g_tot=lg_tot
|
|
255
|
+
|
|
249
256
|
self.imin = self.imin + self.batchsz
|
|
250
257
|
|
|
251
258
|
if self.mpi_size == 1:
|
|
@@ -295,24 +302,25 @@ class Synthesis:
|
|
|
295
302
|
|
|
296
303
|
# ---------------------------------------------−---------
|
|
297
304
|
def run(
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
305
|
+
self,
|
|
306
|
+
in_x,
|
|
307
|
+
NUM_EPOCHS=100,
|
|
308
|
+
DECAY_RATE=0.95,
|
|
309
|
+
EVAL_FREQUENCY=100,
|
|
310
|
+
DEVAL_STAT_FREQUENCY=1000,
|
|
311
|
+
NUM_STEP_BIAS=1,
|
|
312
|
+
LEARNING_RATE=0.03,
|
|
313
|
+
EPSILON=1e-7,
|
|
314
|
+
KEEP_TRACK=None,
|
|
315
|
+
grd_mask=None,
|
|
316
|
+
SHOWGPU=False,
|
|
317
|
+
MESSAGE="",
|
|
318
|
+
factr=10.0,
|
|
319
|
+
batchsz=1,
|
|
320
|
+
totalsz=1,
|
|
321
|
+
do_lbfgs=True,
|
|
322
|
+
idx_grd=None,
|
|
323
|
+
axis=0,
|
|
316
324
|
):
|
|
317
325
|
|
|
318
326
|
self.KEEP_TRACK = KEEP_TRACK
|
|
@@ -326,6 +334,7 @@ class Synthesis:
|
|
|
326
334
|
self.batchsz = batchsz
|
|
327
335
|
self.totalsz = totalsz
|
|
328
336
|
self.grd_mask = grd_mask
|
|
337
|
+
self.idx_grd = idx_grd
|
|
329
338
|
self.EVAL_FREQUENCY = EVAL_FREQUENCY
|
|
330
339
|
self.MESSAGE = MESSAGE
|
|
331
340
|
self.SHOWGPU = SHOWGPU
|