foscat 2025.8.4__py3-none-any.whl → 2025.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +241 -49
- foscat/FoCUS.py +1 -1
- foscat/HOrientedConvol.py +446 -42
- foscat/HealBili.py +305 -0
- foscat/Plot.py +328 -0
- foscat/UNET.py +455 -178
- foscat/healpix_unet_torch.py +717 -0
- foscat/scat_cov.py +1 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/METADATA +1 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/RECORD +13 -10
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/WHEEL +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/top_level.txt +0 -0
foscat/BkTorch.py
CHANGED
|
@@ -63,70 +63,262 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
63
63
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
64
64
|
)
|
|
65
65
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
66
|
+
# ---------------------------------
|
|
67
|
+
# HEALPix binning utilities (nested)
|
|
68
|
+
# ---------------------------------
|
|
69
|
+
# Robust binned_mean that supports arbitrary subsets (N not divisible by 4)
|
|
70
|
+
# and batched cell_ids of shape [B, N]. It returns compact per-parent means
|
|
71
|
+
# even when some parents are missing (sparse coverage).
|
|
72
|
+
|
|
73
|
+
def binned_mean(self, data, cell_ids, *, padded: bool = False, fill_value: float = float("nan")):
    """Average values over parent HEALPix pixels (nested) when downgrading nside -> nside/2.

    Each child pixel ``c`` is assigned to the parent ``c // 4``; the mean is taken
    over whichever children of a parent are actually present, so the input may be
    full-sky or an arbitrary sparse subset (``N`` need not be divisible by 4).

    Parameters
    ----------
    data : torch.Tensor or np.ndarray
        Shape ``[..., N]`` or ``[B_data, ..., N]``. NumPy input is converted to
        ``float32``.
    cell_ids : torch.LongTensor or np.ndarray
        Shape ``[N]`` or ``[B, N]`` (nested indexing at the *child* resolution).
        Its last dimension must equal the last dimension of ``data``.
        When 2-D, ``B_data`` must be a multiple of ``B``: with ``k = B_data // B``,
        row ``b`` of ``cell_ids`` is applied to the ``k`` consecutive leading
        entries ``data[b*k:(b+1)*k]``.
    padded : bool, optional (default: False)
        Only used when ``cell_ids`` is ``[B, N]``. If ``False``, returns Python
        lists (ragged) of per-batch results. If ``True``, returns padded tensors
        plus a boolean mask of valid bins.
    fill_value : float, optional
        Value used for padding when ``padded=True``.

    Returns
    -------
    If ``cell_ids`` is ``[N]``:
        mean  : torch.Tensor, shape ``[..., n_bins]``
        groups: torch.LongTensor, shape ``[n_bins]`` (sorted unique parents)

    If ``cell_ids`` is ``[B, N]`` and ``padded=False``:
        means_list  : List[torch.Tensor] of length B, each shape ``[T, n_bins_b]``
                      where ``T = prod(data.shape[:-1]) // B`` (the flattened data
                      rows that belong to batch entry ``b``).
        groups_list : List[torch.LongTensor] of length B, each shape ``[n_bins_b]``

    If ``cell_ids`` is ``[B, N]`` and ``padded=True``:
        mean_padded : torch.Tensor, shape ``[*data.shape[:-1], max_bins]``
        groups_pad  : torch.LongTensor, shape ``[B, max_bins]`` (parents, padded with -1)
        mask        : torch.BoolTensor, shape ``[B, max_bins]`` (True where valid)
        Leading entry ``i`` of ``mean_padded`` uses row ``i // (B_data // B)`` of
        ``groups_pad`` / ``mask``.

    Raises
    ------
    ValueError
        If ``data`` is 0-d, if ``cell_ids`` is not 1-D or 2-D, if the last
        dimensions of ``data`` and ``cell_ids`` differ, or if the batch
        dimensions are incompatible.
    """
    import numpy as np
    import torch

    def _rowwise_bin_mean(rows, inv, n_bins):
        """Mean of every row of ``rows`` ([R, N]) over the bins ``inv`` ([N]); returns [R, n_bins]."""
        n_rows = rows.shape[0]
        # Shift each row's bin indices into its own slot of a flat [R * n_bins]
        # accumulator so that one scatter_add_ handles all rows independently.
        offsets = torch.arange(n_rows, device=rows.device).unsqueeze(1) * n_bins
        flat_idx = (inv.unsqueeze(0) + offsets).reshape(-1)
        flat_vals = rows.reshape(-1)

        total = torch.zeros(n_rows * n_bins, dtype=rows.dtype, device=rows.device)
        count = torch.zeros_like(total)
        total.scatter_add_(0, flat_idx, flat_vals)
        count.scatter_add_(0, flat_idx, torch.ones_like(flat_vals))
        count.clamp_(min=1)  # defensive: never divide by zero
        return (total / count).view(n_rows, n_bins)

    # ---- Tensorize & device/dtype plumbing ----
    target_device = getattr(self, "torch_device", None)
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data).to(dtype=torch.float32)
    if isinstance(cell_ids, np.ndarray):
        cell_ids = torch.from_numpy(cell_ids)
    if target_device is not None:
        data = data.to(device=target_device)
    cell_ids = cell_ids.to(device=data.device, dtype=torch.long)

    # ---- Shape validation ----
    if data.ndim < 1:
        raise ValueError("`data` must have at least 1 dimension (last is N).")
    if cell_ids.ndim not in (1, 2):
        raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
    N = data.shape[-1]
    if cell_ids.shape[-1] != N:
        # Without this check scatter_add_ would silently ignore trailing samples
        # when cell_ids is shorter than the data.
        raise ValueError(
            f"Last dim of cell_ids ({cell_ids.shape[-1]}) must equal last dim of data ({N})."
        )

    # Flatten all leading dims into rows for scatter convenience.
    lead_shape = tuple(data.shape[:-1])
    n_rows = int(np.prod(lead_shape)) if lead_shape else 1
    data_flat = data.reshape(n_rows, N)  # [R, N]

    if cell_ids.ndim == 1:
        # One shared child -> parent mapping for every row.
        parents, inv = torch.unique(cell_ids // 4, sorted=True, return_inverse=True)
        n_bins = parents.numel()
        mean = _rowwise_bin_mean(data_flat, inv, n_bins)
        return mean.reshape(*lead_shape, n_bins), parents

    # ---- cell_ids is [B, N]: one mapping per batch entry ----
    B = cell_ids.shape[0]
    if data.ndim < 2:
        raise ValueError("`data` must have a leading batch dimension when `cell_ids` is [B, N].")
    if B == 0 or data.shape[0] % B != 0:
        raise ValueError(
            f"Leading dim of data ({data.shape[0]}) must be a multiple of cell_ids batch ({B})."
        )
    # Rows of data_flat governed by each cell_ids row. Dividing by B (not by
    # data.shape[0]) ensures every data row is consumed when data.shape[0] > B.
    rows_per_map = n_rows // B

    means_list, groups_list = [], []
    for b in range(B):
        parents_b, inv_b = torch.unique(cell_ids[b] // 4, sorted=True, return_inverse=True)
        block = data_flat[b * rows_per_map:(b + 1) * rows_per_map]  # [T, N]
        means_list.append(_rowwise_bin_mean(block, inv_b, parents_b.numel()))
        groups_list.append(parents_b)

    if not padded:
        return means_list, groups_list

    # ---- Padded output ----
    max_bins = max(g.numel() for g in groups_list)
    # Use the dtype of the computed means (floating even for integer input) so
    # that a NaN fill value is representable.
    mean_pad = torch.full(
        (B, rows_per_map, max_bins), fill_value, dtype=means_list[0].dtype, device=data.device
    )
    groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
    mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
    for b, (m_b, g_b) in enumerate(zip(means_list, groups_list)):
        nb = g_b.numel()
        mean_pad[b, :, :nb] = m_b
        groups_pad[b, :nb] = g_b
        mask[b, :nb] = True

    # Restore the original leading dims: [B_data, (*extra dims), max_bins].
    return mean_pad.reshape(*lead_shape, max_bins), groups_pad, mask
|
|
215
|
+
'''
|
|
216
|
+
def binned_mean(self, data, cell_ids):
|
|
217
|
+
"""
|
|
218
|
+
Moyenne par groupes de 4 pixels HEALPix nested (nside -> nside/2),
|
|
219
|
+
fonctionne avec un sous-ensemble arbitraire de pixels.
|
|
220
|
+
|
|
221
|
+
Args
|
|
222
|
+
----
|
|
223
|
+
data: torch.Tensor | np.ndarray, shape [..., N] ou [B, ..., N]
|
|
224
|
+
cell_ids: torch.LongTensor | np.ndarray, shape [N] ou [B, N] (nested)
|
|
225
|
+
|
|
226
|
+
Returns
|
|
227
|
+
-------
|
|
228
|
+
mean: torch.Tensor, shape [..., G] ou [B, ..., G]
|
|
229
|
+
groups_out: torch.LongTensor, shape [G] (ids HEALPix parents à nside/2)
|
|
230
|
+
"""
|
|
113
231
|
|
|
114
|
-
#
|
|
115
|
-
|
|
116
|
-
|
|
232
|
+
# --- to tensors on device ---
|
|
233
|
+
if isinstance(data, np.ndarray):
|
|
234
|
+
data = torch.from_numpy(data).to(dtype=torch.float32, device=self.torch_device)
|
|
235
|
+
if isinstance(cell_ids, np.ndarray):
|
|
236
|
+
cell_ids = torch.from_numpy(cell_ids).to(dtype=torch.long, device=self.torch_device)
|
|
237
|
+
data = data.to(self.torch_device)
|
|
238
|
+
cell_ids = cell_ids.to(self.torch_device, dtype=torch.long)
|
|
117
239
|
|
|
118
|
-
#
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
240
|
+
# --- shapes ---
|
|
241
|
+
if data.ndim < 1:
|
|
242
|
+
raise ValueError("`data` must have at least 1 dim; last is N.")
|
|
243
|
+
N = data.shape[-1]
|
|
244
|
+
if N % 4 != 0:
|
|
245
|
+
raise ValueError(f"N={N} must be divisible by 4 for nested groups of 4.")
|
|
246
|
+
|
|
247
|
+
# --- parent groups @ nside/2, accept [N] or [B,N] ---
|
|
248
|
+
if cell_ids.ndim == 1:
|
|
249
|
+
if cell_ids.shape[0] != N:
|
|
250
|
+
raise ValueError(f"cell_ids shape {tuple(cell_ids.shape)} incompatible with N={N}.")
|
|
251
|
+
groups_parent = (cell_ids // 4).long() # [N]
|
|
252
|
+
# densification -> [0..G-1]
|
|
253
|
+
unique_groups, inverse = torch.unique(groups_parent, return_inverse=True) # [G], [N]
|
|
254
|
+
# mapping identique pour toutes les lignes/rows de data
|
|
255
|
+
B_ids = 1
|
|
256
|
+
elif cell_ids.ndim == 2:
|
|
257
|
+
B_ids, N_ids = cell_ids.shape
|
|
258
|
+
if N_ids != N:
|
|
259
|
+
raise ValueError(f"cell_ids last dim {N_ids} must equal N={N}.")
|
|
260
|
+
# vérif compatibilité batch (data doit commencer par B ou par un multiple de B)
|
|
261
|
+
leading = data.shape[:-1]
|
|
262
|
+
if len(leading) == 0:
|
|
263
|
+
raise ValueError("`data` must have a leading dim to match [B, N] cell_ids.")
|
|
264
|
+
B_data = leading[0]
|
|
265
|
+
if B_data % B_ids != 0:
|
|
266
|
+
raise ValueError(f"Leading batch of data ({B_data}) must be a multiple of cell_ids batch ({B_ids}).")
|
|
267
|
+
|
|
268
|
+
# Construire un mapping DENSE par batch, mais on impose que la topologie
|
|
269
|
+
# des parents soit la même pour tous les batches -> on se base sur le batch 0
|
|
270
|
+
groups_parent0 = (cell_ids[0] // 4).long() # [N]
|
|
271
|
+
unique_groups, inverse0 = torch.unique(groups_parent0, return_inverse=True) # [G], [N]
|
|
272
|
+
|
|
273
|
+
# Vérification (optionnelle mais sûre) : chaque batch a les mêmes parents (ordre potentiellement différent OK)
|
|
274
|
+
# -> ici on exige même l'égalité stricte pour éviter les surprises ;
|
|
275
|
+
# sinon on pourrait densifier par-batch et retourner une liste de groups_out.
|
|
276
|
+
for b in range(1, B_ids):
|
|
277
|
+
if not torch.equal(groups_parent0, (cell_ids[b] // 4).long()):
|
|
278
|
+
raise ValueError("All batches in cell_ids must share the same parent groups (order & content).")
|
|
279
|
+
|
|
280
|
+
# Construire l'inverse pour tous les batches en répliquant celui du batch 0
|
|
281
|
+
inverse = inverse0.unsqueeze(0).expand(B_ids, -1) # [B_ids, N]
|
|
282
|
+
else:
|
|
283
|
+
raise ValueError("`cell_ids` must be [N] or [B, N].")
|
|
122
284
|
|
|
123
|
-
#
|
|
124
|
-
mean = out / counts # Shape: [B * n_bins]
|
|
125
|
-
mean = mean.view(B, n_bins)
|
|
285
|
+
G = unique_groups.numel() # nb de groupes parents (nside/2)
|
|
126
286
|
|
|
127
|
-
#
|
|
128
|
-
|
|
287
|
+
# --- aplatir data en lignes ---
|
|
288
|
+
original_shape = data.shape[:-1] # e.g. [B, D1, ...]
|
|
289
|
+
R = int(np.prod(original_shape)) if original_shape else 1
|
|
290
|
+
data_flat = data.reshape(R, N) # [R, N]
|
|
129
291
|
|
|
292
|
+
# --- construire indices de bins par ligne ---
|
|
293
|
+
if cell_ids.ndim == 1:
|
|
294
|
+
idx = inverse.expand(R, -1) # [R, N], dans [0..G-1]
|
|
295
|
+
else:
|
|
296
|
+
# cell_ids est [B_ids, N]; on doit « étirer » chaque ligne de mapping
|
|
297
|
+
# pour couvrir les R lignes de data_flat en respectant la 1ère dim (B_data)
|
|
298
|
+
B_data = original_shape[0]
|
|
299
|
+
T = R // B_data # répétitions par batch-row
|
|
300
|
+
idx = inverse.repeat_interleave(T, dim=0) # [B_ids*T, N] == [R, N]
|
|
301
|
+
|
|
302
|
+
# --- scatter add (somme et compte) par ligne ---
|
|
303
|
+
device = data.device
|
|
304
|
+
row_offsets = torch.arange(R, device=device).unsqueeze(1) * G
|
|
305
|
+
idx_offset = idx.to(torch.long) + row_offsets # [R,N]
|
|
306
|
+
idx_offset_flat = idx_offset.reshape(-1)
|
|
307
|
+
vals_flat = data_flat.reshape(-1)
|
|
308
|
+
|
|
309
|
+
out_sum = torch.zeros(R * G, dtype=data.dtype, device=device)
|
|
310
|
+
out_sum.scatter_add_(0, idx_offset_flat, vals_flat)
|
|
311
|
+
|
|
312
|
+
ones = torch.ones_like(vals_flat, dtype=data.dtype, device=device)
|
|
313
|
+
out_cnt = torch.zeros(R * G, dtype=data.dtype, device=device)
|
|
314
|
+
out_cnt.scatter_add_(0, idx_offset_flat, ones)
|
|
315
|
+
out_cnt = torch.clamp(out_cnt, min=1)
|
|
316
|
+
|
|
317
|
+
mean = (out_sum / out_cnt).view(R, G).view(*original_shape, G)
|
|
318
|
+
|
|
319
|
+
# On retourne les VRAIS ids HEALPix parents à nside/2
|
|
320
|
+
return mean, unique_groups
|
|
321
|
+
'''
|
|
130
322
|
def average_by_cell_group(data, cell_ids):
|
|
131
323
|
"""
|
|
132
324
|
data: tensor of shape [..., N, ...] (ex: [B, N, C])
|