foscat 2025.9.1__tar.gz → 2025.9.3__tar.gz
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in the public registry.
- {foscat-2025.9.1/src/foscat.egg-info → foscat-2025.9.3}/PKG-INFO +1 -1
- {foscat-2025.9.1 → foscat-2025.9.3}/pyproject.toml +1 -1
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/BkTorch.py +160 -93
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/FoCUS.py +74 -267
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/HOrientedConvol.py +233 -250
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/HealBili.py +12 -8
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/Plot.py +9 -6
- foscat-2025.9.3/src/foscat/SphericalStencil.py +1346 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/UNET.py +21 -7
- foscat-2025.9.3/src/foscat/healpix_unet_torch.py +1202 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/scat_cov.py +2 -0
- {foscat-2025.9.1 → foscat-2025.9.3/src/foscat.egg-info}/PKG-INFO +1 -1
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat.egg-info/SOURCES.txt +1 -0
- foscat-2025.9.1/src/foscat/healpix_unet_torch.py +0 -717
- {foscat-2025.9.1 → foscat-2025.9.3}/LICENSE +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/README.md +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/setup.cfg +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/BkBase.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/BkNumpy.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/BkTensorflow.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/CNN.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/CircSpline.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/GCNN.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/HealSpline.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/Softmax.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/Spline1D.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/Synthesis.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/__init__.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/alm.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/backend.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/backend_tens.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/heal_NN.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/loss_backend_tens.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/loss_backend_torch.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/scat.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/scat1D.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/scat2D.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/scat_cov1D.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/scat_cov2D.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/scat_cov_map.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat/scat_cov_map2D.py +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat.egg-info/dependency_links.txt +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat.egg-info/requires.txt +0 -0
- {foscat-2025.9.1 → foscat-2025.9.3}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: foscat
-Version: 2025.9.1
+Version: 2025.9.3
 Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
 Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
 Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -70,7 +70,7 @@ class BkTorch(BackendBase.BackendBase):
     # and batched cell_ids of shape [B, N]. It returns compact per-parent means
     # even when some parents are missing (sparse coverage).
 
-    def
+    def binned_mean_old(self, data, cell_ids, *, padded: bool = False, fill_value: float = float("nan")):
         """Average values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
 
         Works with full-sky or sparse subsets (no need for N to be divisible by 4).
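The `cell_ids // 4` parent mapping used throughout this module relies on the HEALPix NESTED scheme, in which the four children of parent pixel p at nside occupy indices 4p..4p+3 at 2*nside. A minimal check of that property, assuming healpy is installed (healpy is not part of this diff and the nside values are illustrative):

```python
import healpy as hp
import numpy as np

# In NESTED ordering, the 4 children of parent pixel p at nside sit at
# indices 4p..4p+3 at 2*nside, so child_id // 4 recovers the parent id.
nside_child, nside_parent = 8, 4
child = np.arange(hp.nside2npix(nside_child))            # all child pixels (nested)
theta, phi = hp.pix2ang(nside_child, child, nest=True)   # child pixel centres
parent = hp.ang2pix(nside_parent, theta, phi, nest=True)

assert np.array_equal(parent, child // 4)
print(parent[:8])  # [0 0 0 0 1 1 1 1]
```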
@@ -211,114 +211,181 @@ class BkTorch(BackendBase.BackendBase):
             return mean_pad, groups_pad, mask
 
         else:
-            raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
-
-    def binned_mean(
+            raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
+
+    def binned_mean(
+        self,
+        data,
+        cell_ids,
+        *,
+        reduce: str = "mean",  # <-- NEW: "mean" (default) or "max"
+        padded: bool = False,
+        fill_value: float = float("nan"),
+    ):
         """
-
-        works with an arbitrary subset of pixels.
+        Reduce values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
 
-
-
-
-
+        Parameters
+        ----------
+        data : torch.Tensor | np.ndarray
+            Shape [..., N] or [B, ..., N].
+        cell_ids : torch.LongTensor | np.ndarray
+            Shape [N] or [B, N] (nested indexing at the child resolution).
+        reduce : {"mean","max"}, default "mean"
+            Aggregation to apply within each parent group of 4 children.
+        padded : bool, default False
+            Only used when `cell_ids` is [B, N]. If False, returns ragged Python lists.
+            If True, returns padded tensors + mask.
+        fill_value : float, default NaN
+            Padding value when `padded=True`.
 
         Returns
-
-
-
+        -------
+        # same as the existing doc, but the value is a mean (reduce="mean")
+        # or a maximum (reduce="max").
         """
-
-        #
+
+        # ---- Tensorize & device/dtype plumbing ----
         if isinstance(data, np.ndarray):
-            data = torch.from_numpy(data).to(
+            data = torch.from_numpy(data).to(
+                dtype=torch.float32, device=getattr(self, "torch_device", "cpu")
+            )
         if isinstance(cell_ids, np.ndarray):
-            cell_ids = torch.from_numpy(cell_ids).to(
-
-
+            cell_ids = torch.from_numpy(cell_ids).to(
+                dtype=torch.long, device=data.device
+            )
+        data = data.to(device=getattr(self, "torch_device", data.device))
+        cell_ids = cell_ids.to(device=data.device, dtype=torch.long)
 
-        # --- shapes ---
         if data.ndim < 1:
-            raise ValueError("`data` must have at least 1
+            raise ValueError("`data` must have at least 1 dimension (last is N).")
         N = data.shape[-1]
-        if N % 4 != 0:
-            raise ValueError(f"N={N} must be divisible by 4 for nested groups of 4.")
 
-        #
+        # Utilities for 'max' (fallback if scatter_reduce_ is unavailable)
+        def _segment_max(vals_flat, idx_flat, out_size):
+            """Return out[out_idx] = max(vals[idx == out_idx]), vectorized when possible."""
+            # PyTorch >= 1.12 / 2.0: scatter_reduce_ available
+            if hasattr(torch.Tensor, "scatter_reduce_"):
+                out = torch.full((out_size,), -float("inf"),
+                                 dtype=vals_flat.dtype, device=vals_flat.device)
+                out.scatter_reduce_(0, idx_flat, vals_flat, reduce="amax", include_self=True)
+                return out
+            # Simple fallback (loop over unique indices), fine for downsampling
+            out = torch.full((out_size,), -float("inf"),
+                             dtype=vals_flat.dtype, device=vals_flat.device)
+            uniq = torch.unique(idx_flat)
+            for u in uniq.tolist():
+                m = (idx_flat == u)
+                # avoid max() on an empty tensor
+                if m.any():
+                    out[u] = torch.max(vals_flat[m])
+            return out
+
+        # ---- Flatten leading dims for scatter convenience ----
+        orig = data.shape[:-1]
         if cell_ids.ndim == 1:
-
-
-
-
-
-
-
+            # Shared mapping for all rows
+            groups = (cell_ids // 4).to(torch.long)  # [N]
+            parents, inv = torch.unique(groups, sorted=True, return_inverse=True)
+            n_bins = parents.numel()
+
+            R = int(np.prod(orig)) if len(orig) > 0 else 1
+            data_flat = data.reshape(R, N)  # [R, N]
+            row_offsets = torch.arange(R, device=data.device).unsqueeze(1) * n_bins
+            idx = inv.unsqueeze(0).expand(R, -1) + row_offsets  # [R, N]
+
+            vals_flat = data_flat.reshape(-1)
+            idx_flat = idx.reshape(-1)
+            out_size = R * n_bins
+
+            if reduce == "mean":
+                out_sum = torch.zeros(out_size, dtype=data.dtype, device=data.device)
+                out_cnt = torch.zeros_like(out_sum)
+                out_sum.scatter_add_(0, idx_flat, vals_flat)
+                out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
+                out_cnt.clamp_(min=1)
+                reduced = out_sum / out_cnt
+            elif reduce == "max":
+                reduced = _segment_max(vals_flat, idx_flat, out_size)
+            else:
+                raise ValueError("reduce must be 'mean' or 'max'.")
+
+            output = reduced.view(*orig, n_bins)
+            return output, parents
+
         elif cell_ids.ndim == 2:
-
-
-
-
-
-            if
-            raise ValueError(
-
-
-
-
-            # Build a DENSE mapping per batch, but require the parent topology
-            # to be the same for all batches -> we rely on batch 0
-            groups_parent0 = (cell_ids[0] // 4).long()  # [N]
-            unique_groups, inverse0 = torch.unique(groups_parent0, return_inverse=True)  # [G], [N]
-
-            # Check (optional but safe): every batch has the same parents (a potentially different order is OK)
-            # -> here we even require strict equality to avoid surprises;
-            # otherwise we could densify per batch and return a list of groups_out.
-            for b in range(1, B_ids):
-                if not torch.equal(groups_parent0, (cell_ids[b] // 4).long()):
-                    raise ValueError("All batches in cell_ids must share the same parent groups (order & content).")
-
-            # Build the inverse for all batches by replicating the one from batch 0
-            inverse = inverse0.unsqueeze(0).expand(B_ids, -1)  # [B_ids, N]
-        else:
-            raise ValueError("`cell_ids` must be [N] or [B, N].")
+            # Per-batch mapping
+            B = cell_ids.shape[0]
+            R = int(np.prod(orig)) if len(orig) > 0 else 1
+            data_flat = data.reshape(R, N)  # [R, N]
+            B_data = data.shape[0] if len(orig) > 0 else 1
+            if B_data % B != 0:
+                raise ValueError(
+                    f"Leading dim of data ({B_data}) must be a multiple of cell_ids batch ({B})."
+                )
+            # T = repeats per batch row (product of extra leading dims)
+            T = (R // B_data) if B_data > 0 else 1
 
-
+            means_list, groups_list = [], []
+            max_bins = 0
 
-
-
-
-
+            for b in range(B):
+                groups_b = (cell_ids[b] // 4).to(torch.long)  # [N]
+                parents_b, inv_b = torch.unique(groups_b, sorted=True, return_inverse=True)
+                n_bins_b = parents_b.numel()
+                max_bins = max(max_bins, n_bins_b)
+
+                # rows for this batch in data_flat
+                start, stop = b * T, (b + 1) * T
+                rows = slice(start, stop)  # T rows
+
+                row_offsets = torch.arange(T, device=data.device).unsqueeze(1) * n_bins_b
+                idx = inv_b.unsqueeze(0).expand(T, -1) + row_offsets  # [T, N]
+
+                vals_flat = data_flat[rows].reshape(-1)
+                idx_flat = idx.reshape(-1)
+                out_size = T * n_bins_b
+
+                if reduce == "mean":
+                    out_sum = torch.zeros(out_size, dtype=data.dtype, device=data.device)
+                    out_cnt = torch.zeros_like(out_sum)
+                    out_sum.scatter_add_(0, idx_flat, vals_flat)
+                    out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
+                    out_cnt.clamp_(min=1)
+                    reduced_bt = (out_sum / out_cnt).view(T, n_bins_b)
+                elif reduce == "max":
+                    reduced_bt = _segment_max(vals_flat, idx_flat, out_size).view(T, n_bins_b)
+                else:
+                    raise ValueError("reduce must be 'mean' or 'max'.")
+
+                means_list.append(reduced_bt)
+                groups_list.append(parents_b)
+
+            if not padded:
+                return means_list, groups_list
+
+            # Padded output (B, T, max_bins) [+ mask]
+            mean_pad = torch.full((B, T, max_bins), fill_value, dtype=data.dtype, device=data.device)
+            groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
+            mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
+            for b, (m_b, g_b) in enumerate(zip(means_list, groups_list)):
+                nb = g_b.numel()
+                mean_pad[b, :, :nb] = m_b
+                groups_pad[b, :nb] = g_b
+                mask[b, :nb] = True
+
+            # Reshape back to [B, (*extra dims), max_bins] if needed
+            if len(orig) > 1:
+                extra = orig[1:]
+                mean_pad = mean_pad.view(B, *extra, max_bins)
+            else:
+                mean_pad = mean_pad.view(B, max_bins)
+
+            return mean_pad, groups_pad, mask
 
-            # --- build bin indices per row ---
-            if cell_ids.ndim == 1:
-                idx = inverse.expand(R, -1)  # [R, N], in [0..G-1]
         else:
-
-
-            B_data = original_shape[0]
-            T = R // B_data  # repetitions per batch row
-            idx = inverse.repeat_interleave(T, dim=0)  # [B_ids*T, N] == [R, N]
-
-            # --- scatter add (sum and count) per row ---
-            device = data.device
-            row_offsets = torch.arange(R, device=device).unsqueeze(1) * G
-            idx_offset = idx.to(torch.long) + row_offsets  # [R,N]
-            idx_offset_flat = idx_offset.reshape(-1)
-            vals_flat = data_flat.reshape(-1)
-
-            out_sum = torch.zeros(R * G, dtype=data.dtype, device=device)
-            out_sum.scatter_add_(0, idx_offset_flat, vals_flat)
-
-            ones = torch.ones_like(vals_flat, dtype=data.dtype, device=device)
-            out_cnt = torch.zeros(R * G, dtype=data.dtype, device=device)
-            out_cnt.scatter_add_(0, idx_offset_flat, ones)
-            out_cnt = torch.clamp(out_cnt, min=1)
-
-            mean = (out_sum / out_cnt).view(R, G).view(*original_shape, G)
-
-            # Return the REAL parent HEALPix ids at nside/2
-            return mean, unique_groups
-            '''
+            raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
+
     def average_by_cell_group(data, cell_ids):
         """
         data: tensor of shape [..., N, ...] (ex: [B, N, C])
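The new `binned_mean` collapses each group of 4 nested children onto its parent with `scatter_add_` (mean) or `scatter_reduce_` (max). A minimal, self-contained sketch of that reduction pattern on a sparse pixel subset; the toy `cell_ids` and `data` values below are illustrative only:

```python
import torch

# Sparse child-resolution subset (nested ids); values are illustrative.
cell_ids = torch.tensor([8, 9, 10, 11, 40, 41, 43])          # [N]
data = torch.tensor([1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0])  # [N]

groups = cell_ids // 4                                        # parent id per child
parents, inv = torch.unique(groups, sorted=True, return_inverse=True)

# Mean per parent: scatter-summed values divided by scatter-summed counts.
out_sum = torch.zeros(parents.numel(), dtype=data.dtype)
out_cnt = torch.zeros_like(out_sum)
out_sum.scatter_add_(0, inv, data)
out_cnt.scatter_add_(0, inv, torch.ones_like(data))
mean = out_sum / out_cnt.clamp(min=1)

# Max per parent via scatter_reduce_ (PyTorch >= 1.12), as in the reduce="max" branch.
out_max = torch.full((parents.numel(),), -float("inf"), dtype=data.dtype)
out_max.scatter_reduce_(0, inv, data, reduce="amax", include_self=True)

print(parents.tolist())  # [2, 10]  parent pixels at nside/2
print(mean.tolist())     # [2.5, 20.0]
print(out_max.tolist())  # [4.0, 30.0]
```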
@@ -346,7 +413,7 @@ class BkTorch(BackendBase.BackendBase):
         return S.numel()
 
     def bk_SparseTensor(self, indice, w, dense_shape=[]):
-        return self.backend.sparse_coo_tensor(indice
+        return self.backend.sparse_coo_tensor(indice, w, dense_shape).coalesce().to_sparse_csr().to(self.torch_device)
 
     def bk_stack(self, list, axis=0):
         return self.backend.stack(list, axis=axis).to(self.torch_device)