foscat 2025.7.2__py3-none-any.whl → 2025.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,546 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import healpy as hp
4
+ from scipy.sparse import csr_array
5
+ import torch
6
+ from scipy.spatial import cKDTree
7
+
8
class HOrientedConvol:
    """Oriented KERNELSZ x KERNELSZ convolutions on a (possibly partial) HEALPix grid.

    For every pixel in ``cell_ids`` (the whole sphere when ``cell_ids`` is None)
    the constructor finds its ``KERNELSZ**2`` nearest pixel centres and rotates
    them into a local frame in which the centre pixel sits on the north pole.

    Attributes set by ``__init__``:

    - ``cell_ids``   : (Npix,) HEALPix pixel ids handled by this object.
    - ``idx_nn``     : (Npix, P) positions (into ``cell_ids``) of each pixel's P
      nearest neighbours; column 0 is the pixel itself (distance 0).
    - ``vec_rot``    : (Npix, P, 3) neighbour unit vectors in the rotated frame.
    - ``t``, ``p``   : (Npix,) colatitude / longitude of each centre pixel.
    - ``local_test`` : True when an explicit ``cell_ids`` subset was given.

    ``make_idx_weights`` later converts ``idx_nn`` to a torch tensor and adds
    ``w_idx`` / ``w_w`` (bilinear corner indices and weights) for
    ``Convol_torch``.
    """

    def __init__(self, nside, KERNELSZ, cell_ids=None, nest=True):
        """Build the neighbourhood geometry.

        Parameters
        ----------
        nside : int
            HEALPix resolution.
        KERNELSZ : int
            Odd kernel width; every pixel gets ``KERNELSZ**2`` neighbours.
        cell_ids : array-like or torch.Tensor, optional
            Pixel ids to work on. ``None`` means the full sphere
            (``12 * nside**2`` pixels).
        nest : bool
            True if pixel ids use NESTED ordering, False for RING.

        Raises
        ------
        ValueError
            If ``KERNELSZ`` is not a positive odd integer.
        """
        if KERNELSZ < 1:
            raise ValueError(f"KERNELSZ must be >= 1; got KERNELSZ={KERNELSZ}.")
        if KERNELSZ % 2 == 0:
            raise ValueError(f"N must be odd so that coordinates are integers from -K..K; got N={KERNELSZ}.")

        if cell_ids is None:
            # Full sphere: position in cell_ids == pixel id.
            self.local_test = False
            self.cell_ids = np.arange(12 * nside**2)
        else:
            self.local_test = True
            self.cell_ids = self._to_numpy(cell_ids)

        # Row i of idx_nn lists positions into self.cell_ids.
        idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
                                          KERNELSZ * KERNELSZ,
                                          nside,
                                          nest=nest,
                                          )
        mat_pt = self.rotation_matrices_from_healpix(nside, self.cell_ids, nest=nest)

        # Angles of every neighbour. The ids must be decoded with the same
        # ordering (`nest`) used by the k-NN search and the rotation matrices;
        # a hard-coded nest=True mislocates every pixel when nest=False.
        t, p = hp.pix2ang(nside, self.cell_ids[idx_nn], nest=nest)
        vec_orig = hp.ang2vec(t, p)

        # vec_rot[k, m, :] = R_k^T applied to neighbour vector m of pixel k, which
        # brings the centre pixel to the north pole.
        # NOTE(review): the subscripts depend on the axis order hp.ang2vec returns
        # for 2-D angle arrays — confirm before touching this einsum.
        self.vec_rot = np.einsum('mki,ijk->kmj', vec_orig, mat_pt)

        self.t = t[:, 0]
        self.p = p[:, 0]
        self.idx_nn = idx_nn
        self.nside = nside
        self.KERNELSZ = KERNELSZ

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array, moving torch tensors to host memory first."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def remap_by_first_column(self, idx: np.ndarray) -> np.ndarray:
        """
        Remap the values in `idx` so that the first column becomes row numbers.

        Every value equal to ``idx[r, 0]`` is replaced by ``r``; values that do
        not occur in the first column become -1. If a value occurs several
        times in the first column, the LAST such row wins.

        Parameters
        ----------
        idx : np.ndarray
            Integer array of shape (N, m).

        Returns
        -------
        np.ndarray
            Integer array of shape (N, m) with remapped indices.

        Raises
        ------
        ValueError
            If ``idx`` is not 2-D.
        """
        idx = np.asarray(idx)
        if idx.ndim != 2:
            raise ValueError("idx must be a 2D array of shape (N, m)")

        keys = idx[:, 0]
        order = np.argsort(keys, kind="stable")
        sorted_keys = keys[order]

        # side="right" - 1 lands on the last (stable) occurrence of each key.
        pos = np.searchsorted(sorted_keys, idx, side="right") - 1
        safe = np.clip(pos, 0, None)
        found = (pos >= 0) & (sorted_keys[safe] == idx)
        return np.where(found, order[safe], -1).astype(int)

    def rotation_matrices_from_healpix(self, nside, hpix_idx, nest=True):
        """
        Compute rotation matrices that move each Healpix pixel center to the North pole.

        Each matrix is R_z(phi) @ R_y(-theta) for the pixel's (theta, phi);
        applying its transpose to the pixel's unit vector gives (0, 0, 1).

        Parameters
        ----------
        nside : int
            Healpix Nside resolution.
        hpix_idx : array_like or torch.Tensor, shape (N,)
            Healpix pixel indices.
        nest : bool, optional
            True if indices are in NESTED ordering, False for RING ordering.

        Returns
        -------
        R : ndarray, shape (3, 3, N)
            Rotation matrices for each pixel index.
        """
        hpix_idx = self._to_numpy(hpix_idx)
        n_pix = hpix_idx.shape[0]

        # theta: colatitude (0 = north pole)
        theta, phi = hp.pix2ang(nside, hpix_idx, nest=nest)

        cphi = np.cos(phi)
        sphi = np.sin(phi)
        cthi = np.cos(-theta)
        sthi = np.sin(-theta)

        # Rotation around Z by phi.
        Rz = np.zeros((3, 3, n_pix))
        Rz[0, 0, :] = cphi
        Rz[0, 1, :] = -sphi
        Rz[1, 0, :] = sphi
        Rz[1, 1, :] = cphi
        Rz[2, 2, :] = 1.0

        # Rotation around Y by -theta.
        Ry = np.zeros((3, 3, n_pix))
        Ry[0, 0, :] = cthi
        Ry[0, 2, :] = -sthi
        Ry[1, 1, :] = 1.0
        Ry[2, 0, :] = sthi
        Ry[2, 2, :] = cthi

        # Per-pixel matrix product Rz @ Ry.
        return np.einsum('ijk,jlk->ilk', Rz, Ry)

    def _choose_depth_for_candidates(self, N, overshoot=2, max_depth=12):
        """
        Pick hierarchy depth d so that ~ 9 * 4**d >= overshoot * N.

        Depth 0 => 9 candidates; 1 => 36; 2 => 144; 3 => 576; 4 => 2304; etc.
        The search stops at ``max_depth``. Not called by other methods of this
        class.
        """
        d = 0
        while 9 * (4 ** d) < overshoot * N and d < max_depth:
            d += 1
        return d

    def knn_healpix_ckdtree(self,
                            hidx, N, nside, *, nest=True,
                            include_self=True,
                            vec_dtype=np.float32,
                            out_dtype=np.int64
                            ):
        """
        k-NN of HEALPix pixels using a cKDTree on pixel-centre unit vectors.

        Euclidean (chord) distance between unit vectors is monotonic in angular
        distance, so the neighbour order is the angular one.

        Parameters
        ----------
        hidx : array-like or torch.Tensor, shape (M,)
            Pixel ids (ordering given by ``nest``).
        N : int
            Number of neighbours requested (>= 1).
        nside : int
            HEALPix resolution.
        nest : bool
            True for NESTED ids, False for RING ids.
        include_self : bool
            Keep each pixel as its own first neighbour.
        vec_dtype, out_dtype : dtype
            Precision of the tree coordinates / of the returned indices.

        Returns
        -------
        ndarray, shape (M, N_eff)
            LOCAL indices (positions 0..M-1 in ``hidx``). ``N_eff`` is
            ``min(N, M)`` when ``include_self`` else ``min(N, max(M-1, 1))``.
            Returns an empty (0, 0) array when ``hidx`` is empty.

        Raises
        ------
        ValueError
            If ``hidx`` is not 1-D or ``N < 1``.
        """
        hidx = self._to_numpy(hidx).astype(np.int64, copy=False)

        if hidx.ndim != 1:
            raise ValueError("hidx must be 1D")
        M = hidx.size
        if M == 0:
            return np.empty((0, 0), dtype=out_dtype)
        if N <= 0:
            raise ValueError("N must be >= 1")

        N_eff = min(N, M if include_self else max(M - 1, 1))

        # Unit vectors (computed from NESTED ids in both orderings).
        hidx_n = hidx if nest else hp.ring2nest(nside, hidx)
        x, y, z = hp.pix2vec(nside, hidx_n, nest=True)
        V = np.stack([x, y, z], axis=1).astype(vec_dtype, copy=False)  # (M, 3)

        tree = cKDTree(V)

        if include_self:
            # Self appears with distance 0 as the first neighbour. For an
            # integer k == 1 cKDTree squeezes the neighbour axis: restore it.
            _, idx = tree.query(V, k=N_eff, workers=-1)
            return np.asarray(idx).reshape(M, N_eff).astype(out_dtype, copy=False)

        if M == 1:
            # A lone pixel has no other neighbour; it is returned as its own.
            return np.zeros((1, N_eff), dtype=out_dtype)

        # Ask for one extra candidate (k <= M because N_eff <= M - 1) and drop
        # exactly one per row: self if present, otherwise the farthest one
        # (self can be missing only with duplicated positions).
        k = N_eff + 1
        _, idx = tree.query(V, k=k, workers=-1)
        idx = np.asarray(idx).reshape(M, k)
        drop = idx == np.arange(M)[:, None]
        drop[~drop.any(axis=1), -1] = True
        return idx[~drop].reshape(M, N_eff).astype(out_dtype, copy=False)

    def make_wavelet_matrix(self,
                            orientations,
                            polar=True,
                            norm_mean=True,
                            norm_std=True,
                            return_index=False,
                            return_smooth=False,
                            ):
        """Build complex oriented wavelet weights on each pixel neighbourhood.

        The kernel is a Gaussian envelope in the rotated local frame, multiplied
        by ``cos(a) - 1j*sin(a)``. Here ``a`` is the neighbour coordinate along
        each orientation, scaled by ``nside * sigma_cosine * pi``.

        Parameters
        ----------
        orientations : array-like, shape (NORIENT,)
            Orientation angles (presumably radians — they enter cos/sin).
        polar : bool
            Measure orientations relative to each pixel's longitude ``p``.
            A hemisphere sign flip is applied, based on ``t < pi/2``.
        norm_mean : bool
            Remove the per-(orientation, pixel) mean of the kernel.
        norm_std : bool
            Scale each kernel by ``1 / sum(|w|)``, computed before the mean
            removal. This now applies even when ``norm_mean`` is False; it used
            to be silently skipped in that case. Also normalises the smoothing
            kernel to unit sum.
        return_index : bool
            Return raw values and (row, col) indices instead of a sparse matrix.
        return_smooth : bool
            With ``return_index``, also return the smoothing kernel and its
            indices. Without ``return_index`` it is computed but not returned.

        Returns
        -------
        csr_array of shape (Npix, Npix * NORIENT), or the tuples described above.
        Column ``o * Npix + c`` holds orientation ``o`` centred on pixel ``c``.
        """
        sigma_gauss = 0.5
        sigma_cosine = 0.5
        if self.KERNELSZ == 3:
            sigma_gauss = 1.0 / np.sqrt(2)
            sigma_cosine = 1.0

        orientations = np.asarray(orientations)
        NORIENT = orientations.shape[0]
        # idx_nn may already be a torch tensor if make_idx_weights ran first.
        idx_nn = self._to_numpy(self.idx_nn)
        npix, NK = idx_nn.shape

        vx = self.vec_rot[None, :, :, 0]
        vy = self.vec_rot[None, :, :, 1]
        vz = self.vec_rot[None, :, :, 2]

        # +1 for centres with t < pi/2, -1 otherwise.
        rotate = 2 * ((self.t < np.pi / 2) - 0.5)[None, :, None]
        if polar:
            ang = (self.p[None, :] + np.pi / 2 - orientations[:, None])[:, :, None]
            xx = np.cos(ang) * vx - rotate * np.sin(ang) * vy
        else:
            ang = np.pi / 2 - orientations[:, None, None]
            xx = np.cos(ang) * vx - np.sin(ang) * vy

        # Squared chord distance to the north pole (the rotated centre).
        r = vx**2 + vy**2 + (vz - 1.0)**2

        if return_smooth:
            wsmooth = np.exp(-sigma_gauss * r * self.nside**2)
            if norm_std:
                wsmooth = wsmooth / np.sum(wsmooth, 2)[:, :, None]

        # for consistency with previous definition
        phase = xx * self.nside * sigma_cosine * np.pi
        w = np.exp(-sigma_gauss * r * self.nside**2) * (np.cos(phase) - 1J * np.sin(phase))

        scale = 1 / np.sum(abs(w), 2)[:, :, None] if norm_std else 1.0
        if norm_mean:
            w = w - np.mean(w, axis=2, keepdims=True)
        w = w * scale

        # rows: neighbour pixel; cols: centre pixel + orientation block offset.
        indice_1_0 = np.tile(idx_nn.flatten(), NORIENT)
        indice_1_1 = (np.tile(np.repeat(idx_nn[:, 0], NK), NORIENT)
                      + np.repeat(np.arange(NORIENT), npix * NK) * npix)
        w = w.flatten()

        if return_index:
            index_1 = np.concatenate([indice_1_0[:, None], indice_1_1[:, None]], 1)
            if return_smooth:
                indice_2_0 = idx_nn.flatten()
                indice_2_1 = np.repeat(idx_nn[:, 0], NK)
                index_2 = np.concatenate([indice_2_0[:, None], indice_2_1[:, None]], 1)
                return w, index_1, wsmooth.flatten(), index_2
            return w, index_1

        # Width uses the same Npix stride as the column offsets above (equal to
        # 12*nside**2 on the full sphere, smaller for a cell_ids subset).
        return csr_array((w, (indice_1_0, indice_1_1)), shape=(npix, npix * NORIENT))

    def make_idx_weights(self, polar=False, gamma=1.0, device='cuda', allow_extrapolation=True):
        """Precompute the bilinear indices/weights used by ``Convol_torch``.

        Neighbours are projected on the local (x, y) plane and scaled by
        ``nside * gamma`` to kernel-grid units. ``self.w_idx`` (long) and
        ``self.w_w`` (float64) are stored, shaped (Npix, 4, P). ``self.idx_nn``
        is converted to a ``torch.long`` tensor. All three live on ``device``.

        Parameters
        ----------
        polar : bool
            Rotate the projection by each pixel's longitude (with a hemisphere
            sign flip) instead of using the raw rotated x/y.
        gamma : float
            Extra scale factor applied to the projected coordinates.
        device : str or torch.device
            Target device of the stored tensors.
        allow_extrapolation : bool
            Forwarded to ``bilinear_weights_NxN``.
        """
        vx = self.vec_rot[:, :, 0]
        vy = self.vec_rot[:, :, 1]
        if polar:
            rotate = 2 * ((self.t < np.pi / 2) - 0.5)[:, None]
            cos_p = np.cos(self.p)[:, None]
            sin_p = np.sin(self.p)[:, None]
            xx = cos_p * vx - rotate * sin_p * vy
            yy = -sin_p * vx - rotate * cos_p * vy
        else:
            xx = vx
            yy = vy

        self.w_idx, self.w_w = self.bilinear_weights_NxN(xx * self.nside * gamma,
                                                         yy * self.nside * gamma,
                                                         allow_extrapolation=allow_extrapolation)

        # torch.as_tensor keeps exact int64/float64 values. torch.Tensor(ndarray)
        # goes through float32, which corrupted pixel indices above 2**24 and
        # rounded the weights. It also accepts an already-converted tensor.
        self.idx_nn = torch.as_tensor(self.idx_nn, dtype=torch.long, device=device)
        self.w_idx = torch.as_tensor(self.w_idx, dtype=torch.long, device=device)
        self.w_w = torch.as_tensor(self.w_w, dtype=torch.float64, device=device)

    def _grid_index(self, xi, yi):
        """
        Map integer grid coords (xi, yi) in {-K..K} (K = KERNELSZ // 2) to a flat
        index in [0, KERNELSZ**2), row-major with y changing slowest.
        """
        return (yi + self.KERNELSZ // 2) * self.KERNELSZ + (xi + self.KERNELSZ // 2)

    def bilinear_weights_NxN(self, x, y, allow_extrapolation=True):
        """
        Compute bilinear weights on the N x N kernel grid, N = self.KERNELSZ.

        Grid nodes sit at integer coordinates (xi, yi) in {-K, ..., +K}, with
        K = N // 2. For each query (x, y) the unit cell [x0, x0+1] x [y0, y0+1]
        is selected, kept inside the grid. The standard bilinear weights of its
        four corners are then returned.

        Parameters
        ----------
        x, y : float or array-like, same shape
            Query coordinates in grid units.
        allow_extrapolation : bool, default True
            - If False: clamp (x, y) to [-K, +K] so that the weights are
              non-negative and sum to 1.
            - If True : do not clamp (x, y). Corner indices stay on the boundary
              cell, but the weights may be negative (linear extrapolation).

        Returns
        -------
        idx : ndarray, int64
            Flat corner indices in ``[0, N*N)``, ordered
            [(x0,y0), (x1,y0), (x0,y1), (x1,y1)]. Shape (M, 4) for 1-D input.
            In general the corner axis is inserted at axis 1, e.g. (Npix, 4, P)
            for (Npix, P) inputs.
        w : ndarray, float64
            Matching bilinear weights, same shape as ``idx``.

        Raises
        ------
        ValueError
            If ``x`` and ``y`` shapes differ, or KERNELSZ < 2. A 1-node grid has
            no cell to interpolate in and used to yield index -1.
        """
        N = self.KERNELSZ
        if N < 2:
            raise ValueError(f"bilinear interpolation needs KERNELSZ >= 3; got KERNELSZ={N}.")
        K = N // 2

        x = np.atleast_1d(np.asarray(x, dtype=float))
        y = np.atleast_1d(np.asarray(y, dtype=float))
        if x.shape != y.shape:
            raise ValueError("x and y must have the same shape")

        if not allow_extrapolation:
            x = np.clip(x, -K, K)
            y = np.clip(y, -K, K)

        # Lower-left corner restricted to [-K, K-1] so that +1 stays on the grid.
        x0 = np.clip(np.floor(x), -K, K - 1).astype(int)
        y0 = np.clip(np.floor(y), -K, K - 1).astype(int)
        x1 = x0 + 1
        y1 = y0 + 1

        tx = x - x0
        ty = y - y0

        w = np.stack([(1.0 - tx) * (1.0 - ty),
                      tx * (1.0 - ty),
                      (1.0 - tx) * ty,
                      tx * ty], axis=1)
        idx = np.stack([self._grid_index(x0, y0),
                        self._grid_index(x1, y0),
                        self._grid_index(x0, y1),
                        self._grid_index(x1, y1)], axis=1).astype(np.int64)
        return idx, w

    def _broadcast_kernel_weights(self, ww, B, C_i, S, M):
        """Return ``ww`` as a (B, C_i, C_o, M, S) view (expanded, no copy)."""
        if ww.ndim == 3:
            C_i_w, C_o, M_w = ww.shape
            assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
            assert M_w == M, f"`ww` must be (C_i, C_o, KERNELSZ*KERNELSZ), got {tuple(ww.shape)}"
            return ww.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, -1, S)

        if ww.ndim == 4:
            # (C_i, C_o, M, S) vs (B, C_i, C_o, M): decide from every known size.
            unbatched = ww.shape[0] == C_i and ww.shape[2] == M and ww.shape[3] == S
            batched = ww.shape[0] == B and ww.shape[1] == C_i and ww.shape[3] == M
            if unbatched and not batched:
                return ww.unsqueeze(0).expand(B, -1, -1, -1, -1)
            if batched and not unbatched:
                return ww.unsqueeze(-1).expand(-1, -1, -1, -1, S)
            kind = "Ambiguous" if unbatched else "Unrecognised"
            raise ValueError(
                f"{kind} 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)"
            )

        if ww.ndim == 5:
            assert ww.shape[0] == B and ww.shape[1] == C_i, "ww batch/C_i mismatch"
            _, _, _, M_w, S_w = ww.shape
            assert M_w == M, f"ww M mismatch: {M_w} vs KERNELSZ*KERNELSZ {M}"
            assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
            return ww

        raise ValueError(f"Unsupported ww shape {tuple(ww.shape)}")

    def Convol_torch(self, im, ww):
        """
        Batched KERNELSZ x KERNELSZ neighbourhood aggregation in pure PyTorch.

        Computes
        ``out[b, o, n] = sum_{i,s,p} im[b, i, idx_nn[n, p]]
        * ww[.., i, o, w_idx[n, s, p], s] * w_w[n, s, p]``.

        Neighbour contributions are first accumulated on the M = KERNELSZ**2
        kernel nodes and only then mixed with ``ww``. This avoids the former
        (B, C_i, C_o, Npix, S, P) intermediate.

        Parameters
        ----------
        im : Tensor, shape (B, C_i, Npix)
            Input features per pixel for a batch of B samples.
        ww : Tensor
            Kernel weights indexed along M by ``self.w_idx``. Supported shapes:
            (C_i, C_o, M), (C_i, C_o, M, S), (B, C_i, C_o, M), (B, C_i, C_o, M, S).

        Class members used (numpy arrays or tensors; moved to ``im.device``):
            self.idx_nn : (Npix, P) neighbour positions.
            self.w_idx  : (Npix, P) or (Npix, S, P) indices into M.
            self.w_w    : same shape as ``w_idx``, extra scalar weights.

        Returns
        -------
        out : Tensor, shape (B, C_o, Npix)
            dtype is ``torch.promote_types(im.dtype, ww.dtype)``.
        """
        assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"

        B, C_i, Npix = im.shape
        device = im.device
        M = self.KERNELSZ * self.KERNELSZ
        # Same promotion the element-wise product im * ww would apply.
        res_dtype = torch.promote_types(im.dtype, ww.dtype)

        idx_nn = torch.as_tensor(self.idx_nn, device=device, dtype=torch.long)
        w_idx = torch.as_tensor(self.w_idx, device=device, dtype=torch.long)
        w_w = torch.as_tensor(self.w_w, device=device).to(res_dtype)

        assert idx_nn.ndim == 2 and idx_nn.size(0) == Npix, \
            f"`idx_nn` must be (Npix, P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
        P = idx_nn.size(1)

        # 1) Neighbour values: rim[b, i, n, p] = im[b, i, idx_nn[n, p]].
        rim = im[:, :, idx_nn].to(res_dtype)

        # 2) Bring w_idx / w_w to (Npix, S, P).
        if w_idx.ndim == 2:
            assert w_idx.size(0) == Npix and w_idx.size(1) == P
            w_idx_eff = w_idx.unsqueeze(1)
            w_w_eff = w_w.unsqueeze(1)
            S = 1
        elif w_idx.ndim == 3:
            Npix_, S, P_ = w_idx.shape
            assert Npix_ == Npix and P_ == P, \
                f"`w_idx` must be (Npix,S,P) with Npix={Npix}, P={P}, got {tuple(w_idx.shape)}"
            assert w_w.shape == w_idx.shape, "`w_w` must match `w_idx` shape"
            w_idx_eff = w_idx
            w_w_eff = w_w
        else:
            raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")

        # 3) ww -> (B, C_i, C_o, M, S).
        ww_eff = self._broadcast_kernel_weights(ww, B, C_i, S, M).to(res_dtype)

        # 4) Kernel map kmap[n, s, p, m] = w_w[n, s, p] * (w_idx[n, s, p] == m).
        #    one_hot also rejects indices outside [0, M).
        kmap = torch.nn.functional.one_hot(w_idx_eff, num_classes=M).to(res_dtype)
        kmap = kmap * w_w_eff.unsqueeze(-1)

        # 5) Accumulate neighbours on kernel nodes, then contract with ww.
        acc = torch.einsum('binp,nspm->binsm', rim, kmap)         # (B, C_i, Npix, S, M)
        return torch.einsum('binsm,bioms->bon', acc, ww_eff)      # (B, C_o, Npix)
544
+
545
+
546
+
foscat/HealSpline.py CHANGED
@@ -51,7 +51,7 @@ class heal_spline:
51
51
  kind='cubic', fill_value='extrapolate')
52
52
 
53
53
 
54
- def ang2weigths(self,th,ph,threshold=1E-2,nest=True):
54
+ def ang2weigths(self,th,ph,threshold=1E-2,nest=False):
55
55
  th0=self.f_interp_th(th).flatten()
56
56
 
57
57
  idx_lat,w_th=self.spline_lat.eval(th0.flatten())
@@ -73,11 +73,13 @@ class heal_spline:
73
73
  www=www.reshape(16,www.shape[2])
74
74
  all_idx=all_idx.reshape(16,all_idx.shape[2])
75
75
 
76
+ if nest:
77
+ all_idx = hp.ring2nest(self.nside,all_idx)
78
+
76
79
  heal_idx,inv_idx = np.unique(all_idx,
77
80
  return_inverse=True)
78
81
  all_idx = inv_idx
79
- if nest:
80
- heal_idx = hp.ring2nest(self.nside,heal_idx)
82
+
81
83
  self.cell_ids = heal_idx
82
84
 
83
85
  hit=np.bincount(all_idx.flatten(),weights=www.flatten())
@@ -190,8 +192,9 @@ class heal_spline:
190
192
  spl=np.zeros([scale])
191
193
  spl[heal_idx[ih==k]-scale*h[k]]=self.spline[ih==k]
192
194
  self.spline_tree[h[k]]=spl
193
-
194
-
195
+
196
def GetParam(self):
    """Return the fitted spline parameters as a ``(heal_idx, spline)`` tuple.

    NOTE(review): in the visible ``ang2weigths`` the pixel indices are stored on
    ``self.cell_ids``, not ``self.heal_idx``. Confirm that ``self.heal_idx`` is
    assigned elsewhere before relying on this accessor.
    """
    indices = self.heal_idx
    coefficients = self.spline
    return (indices, coefficients)
195
198
 
196
199
  def Transform(self,th,ph,threshold=1E-2,nest=True):
197
200
 
foscat/Synthesis.py CHANGED
@@ -195,6 +195,8 @@ class Synthesis:
195
195
  x = self.operation.backend.bk_reshape(
196
196
  self.operation.backend.bk_cast(in_x), self.oshape
197
197
  )
198
+ if self.idx_grd is not None:
199
+ x=x[self.idx_grd]
198
200
 
199
201
  self.l_log[
200
202
  self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
@@ -246,6 +248,11 @@ class Synthesis:
246
248
 
247
249
  g_tot[np.isnan(g_tot)] = 0.0
248
250
 
251
+ if self.idx_grd is not None:
252
+ lg_tot=np.zeros(in_x.shape)
253
+ lg_tot[self.idx_grd]=g_tot
254
+ g_tot=lg_tot
255
+
249
256
  self.imin = self.imin + self.batchsz
250
257
 
251
258
  if self.mpi_size == 1:
@@ -295,24 +302,25 @@ class Synthesis:
295
302
 
296
303
  # ---------------------------------------------−---------
297
304
  def run(
298
- self,
299
- in_x,
300
- NUM_EPOCHS=100,
301
- DECAY_RATE=0.95,
302
- EVAL_FREQUENCY=100,
303
- DEVAL_STAT_FREQUENCY=1000,
304
- NUM_STEP_BIAS=1,
305
- LEARNING_RATE=0.03,
306
- EPSILON=1e-7,
307
- KEEP_TRACK=None,
308
- grd_mask=None,
309
- SHOWGPU=False,
310
- MESSAGE="",
311
- factr=10.0,
312
- batchsz=1,
313
- totalsz=1,
314
- do_lbfgs=True,
315
- axis=0,
305
+ self,
306
+ in_x,
307
+ NUM_EPOCHS=100,
308
+ DECAY_RATE=0.95,
309
+ EVAL_FREQUENCY=100,
310
+ DEVAL_STAT_FREQUENCY=1000,
311
+ NUM_STEP_BIAS=1,
312
+ LEARNING_RATE=0.03,
313
+ EPSILON=1e-7,
314
+ KEEP_TRACK=None,
315
+ grd_mask=None,
316
+ SHOWGPU=False,
317
+ MESSAGE="",
318
+ factr=10.0,
319
+ batchsz=1,
320
+ totalsz=1,
321
+ do_lbfgs=True,
322
+ idx_grd=None,
323
+ axis=0,
316
324
  ):
317
325
 
318
326
  self.KEEP_TRACK = KEEP_TRACK
@@ -326,6 +334,7 @@ class Synthesis:
326
334
  self.batchsz = batchsz
327
335
  self.totalsz = totalsz
328
336
  self.grd_mask = grd_mask
337
+ self.idx_grd = idx_grd
329
338
  self.EVAL_FREQUENCY = EVAL_FREQUENCY
330
339
  self.MESSAGE = MESSAGE
331
340
  self.SHOWGPU = SHOWGPU