foscat 2025.9.1__tar.gz → 2025.9.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {foscat-2025.9.1/src/foscat.egg-info → foscat-2025.9.4}/PKG-INFO +1 -1
  2. {foscat-2025.9.1 → foscat-2025.9.4}/pyproject.toml +1 -1
  3. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/BkTorch.py +160 -93
  4. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/FoCUS.py +80 -267
  5. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/HOrientedConvol.py +233 -250
  6. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/HealBili.py +12 -8
  7. foscat-2025.9.4/src/foscat/Plot.py +1298 -0
  8. foscat-2025.9.4/src/foscat/SphericalStencil.py +1346 -0
  9. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/UNET.py +21 -7
  10. foscat-2025.9.4/src/foscat/healpix_unet_torch.py +1202 -0
  11. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/scat_cov.py +2 -0
  12. {foscat-2025.9.1 → foscat-2025.9.4/src/foscat.egg-info}/PKG-INFO +1 -1
  13. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat.egg-info/SOURCES.txt +1 -0
  14. foscat-2025.9.1/src/foscat/Plot.py +0 -328
  15. foscat-2025.9.1/src/foscat/healpix_unet_torch.py +0 -717
  16. {foscat-2025.9.1 → foscat-2025.9.4}/LICENSE +0 -0
  17. {foscat-2025.9.1 → foscat-2025.9.4}/README.md +0 -0
  18. {foscat-2025.9.1 → foscat-2025.9.4}/setup.cfg +0 -0
  19. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/BkBase.py +0 -0
  20. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/BkNumpy.py +0 -0
  21. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/BkTensorflow.py +0 -0
  22. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/CNN.py +0 -0
  23. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/CircSpline.py +0 -0
  24. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/GCNN.py +0 -0
  25. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/HealSpline.py +0 -0
  26. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/Softmax.py +0 -0
  27. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/Spline1D.py +0 -0
  28. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/Synthesis.py +0 -0
  29. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/__init__.py +0 -0
  30. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/alm.py +0 -0
  31. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/backend.py +0 -0
  32. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/backend_tens.py +0 -0
  33. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/heal_NN.py +0 -0
  34. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/loss_backend_tens.py +0 -0
  35. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/loss_backend_torch.py +0 -0
  36. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/scat.py +0 -0
  37. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/scat1D.py +0 -0
  38. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/scat2D.py +0 -0
  39. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/scat_cov1D.py +0 -0
  40. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/scat_cov2D.py +0 -0
  41. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/scat_cov_map.py +0 -0
  42. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat/scat_cov_map2D.py +0 -0
  43. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat.egg-info/dependency_links.txt +0 -0
  44. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat.egg-info/requires.txt +0 -0
  45. {foscat-2025.9.1 → foscat-2025.9.4}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: foscat
3
- Version: 2025.9.1
3
+ Version: 2025.9.4
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
6
6
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "foscat"
3
- version = "2025.09.1"
3
+ version = "2025.09.4"
4
4
  description = "Generate synthetic Healpix or 2D data using Cross Scattering Transform"
5
5
  readme = "README.md"
6
6
  license = { text = "BSD-3-Clause" }
@@ -70,7 +70,7 @@ class BkTorch(BackendBase.BackendBase):
70
70
  # and batched cell_ids of shape [B, N]. It returns compact per-parent means
71
71
  # even when some parents are missing (sparse coverage).
72
72
 
73
- def binned_mean(self, data, cell_ids, *, padded: bool = False, fill_value: float = float("nan")):
73
+ def binned_mean_old(self, data, cell_ids, *, padded: bool = False, fill_value: float = float("nan")):
74
74
  """Average values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
75
75
 
76
76
  Works with full-sky or sparse subsets (no need for N to be divisible by 4).
@@ -211,114 +211,181 @@ class BkTorch(BackendBase.BackendBase):
211
211
  return mean_pad, groups_pad, mask
212
212
 
213
213
  else:
214
- raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
215
- '''
216
- def binned_mean(self, data, cell_ids):
214
+ raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
215
+
216
+ def binned_mean(
217
+ self,
218
+ data,
219
+ cell_ids,
220
+ *,
221
+ reduce: str = "mean", # <-- NEW: "mean" (par défaut) ou "max"
222
+ padded: bool = False,
223
+ fill_value: float = float("nan"),
224
+ ):
217
225
  """
218
- Moyenne par groupes de 4 pixels HEALPix nested (nside -> nside/2),
219
- fonctionne avec un sous-ensemble arbitraire de pixels.
226
+ Reduce values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
220
227
 
221
- Args
222
- ----
223
- data: torch.Tensor | np.ndarray, shape [..., N] ou [B, ..., N]
224
- cell_ids: torch.LongTensor | np.ndarray, shape [N] ou [B, N] (nested)
228
+ Parameters
229
+ ----------
230
+ data : torch.Tensor | np.ndarray
231
+ Shape [..., N] or [B, ..., N].
232
+ cell_ids : torch.LongTensor | np.ndarray
233
+ Shape [N] or [B, N] (nested indexing at the child resolution).
234
+ reduce : {"mean","max"}, default "mean"
235
+ Aggregation to apply within each parent group of 4 children.
236
+ padded : bool, default False
237
+ Only used when `cell_ids` is [B, N]. If False, returns ragged Python lists.
238
+ If True, returns padded tensors + mask.
239
+ fill_value : float, default NaN
240
+ Padding value when `padded=True`.
225
241
 
226
242
  Returns
227
- -------
228
- mean: torch.Tensor, shape [..., G] ou [B, ..., G]
229
- groups_out: torch.LongTensor, shape [G] (ids HEALPix parents à nside/2)
243
+ -------
244
+ # idem à ta doc existante, mais la valeur est une moyenne (reduce="mean")
245
+ # ou un maximum (reduce="max").
230
246
  """
231
-
232
- # --- to tensors on device ---
247
+
248
+ # ---- Tensorize & device/dtype plumbing ----
233
249
  if isinstance(data, np.ndarray):
234
- data = torch.from_numpy(data).to(dtype=torch.float32, device=self.torch_device)
250
+ data = torch.from_numpy(data).to(
251
+ dtype=torch.float32, device=getattr(self, "torch_device", "cpu")
252
+ )
235
253
  if isinstance(cell_ids, np.ndarray):
236
- cell_ids = torch.from_numpy(cell_ids).to(dtype=torch.long, device=self.torch_device)
237
- data = data.to(self.torch_device)
238
- cell_ids = cell_ids.to(self.torch_device, dtype=torch.long)
254
+ cell_ids = torch.from_numpy(cell_ids).to(
255
+ dtype=torch.long, device=data.device
256
+ )
257
+ data = data.to(device=getattr(self, "torch_device", data.device))
258
+ cell_ids = cell_ids.to(device=data.device, dtype=torch.long)
239
259
 
240
- # --- shapes ---
241
260
  if data.ndim < 1:
242
- raise ValueError("`data` must have at least 1 dim; last is N.")
261
+ raise ValueError("`data` must have at least 1 dimension (last is N).")
243
262
  N = data.shape[-1]
244
- if N % 4 != 0:
245
- raise ValueError(f"N={N} must be divisible by 4 for nested groups of 4.")
246
263
 
247
- # --- parent groups @ nside/2, accept [N] or [B,N] ---
264
+ # Utilitaires pour 'max' (fallback si scatter_reduce_ indisponible)
265
+ def _segment_max(vals_flat, idx_flat, out_size):
266
+ """Retourne out[out_idx] = max(vals[ idx==out_idx ]), vectorisé si possible."""
267
+ # PyTorch >= 1.12 / 2.0: scatter_reduce_ disponible
268
+ if hasattr(torch.Tensor, "scatter_reduce_"):
269
+ out = torch.full((out_size,), -float("inf"),
270
+ dtype=vals_flat.dtype, device=vals_flat.device)
271
+ out.scatter_reduce_(0, idx_flat, vals_flat, reduce="amax", include_self=True)
272
+ return out
273
+ # Fallback simple (boucle sur indices uniques) – OK pour du downsample
274
+ out = torch.full((out_size,), -float("inf"),
275
+ dtype=vals_flat.dtype, device=vals_flat.device)
276
+ uniq = torch.unique(idx_flat)
277
+ for u in uniq.tolist():
278
+ m = (idx_flat == u)
279
+ # éviter max() sur tensor vide
280
+ if m.any():
281
+ out[u] = torch.max(vals_flat[m])
282
+ return out
283
+
284
+ # ---- Flatten leading dims for scatter convenience ----
285
+ orig = data.shape[:-1]
248
286
  if cell_ids.ndim == 1:
249
- if cell_ids.shape[0] != N:
250
- raise ValueError(f"cell_ids shape {tuple(cell_ids.shape)} incompatible with N={N}.")
251
- groups_parent = (cell_ids // 4).long() # [N]
252
- # densification -> [0..G-1]
253
- unique_groups, inverse = torch.unique(groups_parent, return_inverse=True) # [G], [N]
254
- # mapping identique pour toutes les lignes/rows de data
255
- B_ids = 1
287
+ # Shared mapping for all rows
288
+ groups = (cell_ids // 4).to(torch.long) # [N]
289
+ parents, inv = torch.unique(groups, sorted=True, return_inverse=True)
290
+ n_bins = parents.numel()
291
+
292
+ R = int(np.prod(orig)) if len(orig) > 0 else 1
293
+ data_flat = data.reshape(R, N) # [R, N]
294
+ row_offsets = torch.arange(R, device=data.device).unsqueeze(1) * n_bins
295
+ idx = inv.unsqueeze(0).expand(R, -1) + row_offsets # [R, N]
296
+
297
+ vals_flat = data_flat.reshape(-1)
298
+ idx_flat = idx.reshape(-1)
299
+ out_size = R * n_bins
300
+
301
+ if reduce == "mean":
302
+ out_sum = torch.zeros(out_size, dtype=data.dtype, device=data.device)
303
+ out_cnt = torch.zeros_like(out_sum)
304
+ out_sum.scatter_add_(0, idx_flat, vals_flat)
305
+ out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
306
+ out_cnt.clamp_(min=1)
307
+ reduced = out_sum / out_cnt
308
+ elif reduce == "max":
309
+ reduced = _segment_max(vals_flat, idx_flat, out_size)
310
+ else:
311
+ raise ValueError("reduce must be 'mean' or 'max'.")
312
+
313
+ output = reduced.view(*orig, n_bins)
314
+ return output, parents
315
+
256
316
  elif cell_ids.ndim == 2:
257
- B_ids, N_ids = cell_ids.shape
258
- if N_ids != N:
259
- raise ValueError(f"cell_ids last dim {N_ids} must equal N={N}.")
260
- # vérif compatibilité batch (data doit commencer par B ou par un multiple de B)
261
- leading = data.shape[:-1]
262
- if len(leading) == 0:
263
- raise ValueError("`data` must have a leading dim to match [B, N] cell_ids.")
264
- B_data = leading[0]
265
- if B_data % B_ids != 0:
266
- raise ValueError(f"Leading batch of data ({B_data}) must be a multiple of cell_ids batch ({B_ids}).")
267
-
268
- # Construire un mapping DENSE par batch, mais on impose que la topologie
269
- # des parents soit la même pour tous les batches -> on se base sur le batch 0
270
- groups_parent0 = (cell_ids[0] // 4).long() # [N]
271
- unique_groups, inverse0 = torch.unique(groups_parent0, return_inverse=True) # [G], [N]
272
-
273
- # Vérification (optionnelle mais sûre) : chaque batch a les mêmes parents (ordre potentiellement différent OK)
274
- # -> ici on exige même l'égalité stricte pour éviter les surprises ;
275
- # sinon on pourrait densifier par-batch et retourner une liste de groups_out.
276
- for b in range(1, B_ids):
277
- if not torch.equal(groups_parent0, (cell_ids[b] // 4).long()):
278
- raise ValueError("All batches in cell_ids must share the same parent groups (order & content).")
279
-
280
- # Construire l'inverse pour tous les batches en répliquant celui du batch 0
281
- inverse = inverse0.unsqueeze(0).expand(B_ids, -1) # [B_ids, N]
282
- else:
283
- raise ValueError("`cell_ids` must be [N] or [B, N].")
317
+ # Per-batch mapping
318
+ B = cell_ids.shape[0]
319
+ R = int(np.prod(orig)) if len(orig) > 0 else 1
320
+ data_flat = data.reshape(R, N) # [R, N]
321
+ B_data = data.shape[0] if len(orig) > 0 else 1
322
+ if B_data % B != 0:
323
+ raise ValueError(
324
+ f"Leading dim of data ({B_data}) must be a multiple of cell_ids batch ({B})."
325
+ )
326
+ # T = repeats per batch row (product of extra leading dims)
327
+ T = (R // B_data) if B_data > 0 else 1
284
328
 
285
- G = unique_groups.numel() # nb de groupes parents (nside/2)
329
+ means_list, groups_list = [], []
330
+ max_bins = 0
286
331
 
287
- # --- aplatir data en lignes ---
288
- original_shape = data.shape[:-1] # e.g. [B, D1, ...]
289
- R = int(np.prod(original_shape)) if original_shape else 1
290
- data_flat = data.reshape(R, N) # [R, N]
332
+ for b in range(B):
333
+ groups_b = (cell_ids[b] // 4).to(torch.long) # [N]
334
+ parents_b, inv_b = torch.unique(groups_b, sorted=True, return_inverse=True)
335
+ n_bins_b = parents_b.numel()
336
+ max_bins = max(max_bins, n_bins_b)
337
+
338
+ # rows for this batch in data_flat
339
+ start, stop = b * T, (b + 1) * T
340
+ rows = slice(start, stop) # T rows
341
+
342
+ row_offsets = torch.arange(T, device=data.device).unsqueeze(1) * n_bins_b
343
+ idx = inv_b.unsqueeze(0).expand(T, -1) + row_offsets # [T, N]
344
+
345
+ vals_flat = data_flat[rows].reshape(-1)
346
+ idx_flat = idx.reshape(-1)
347
+ out_size = T * n_bins_b
348
+
349
+ if reduce == "mean":
350
+ out_sum = torch.zeros(out_size, dtype=data.dtype, device=data.device)
351
+ out_cnt = torch.zeros_like(out_sum)
352
+ out_sum.scatter_add_(0, idx_flat, vals_flat)
353
+ out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
354
+ out_cnt.clamp_(min=1)
355
+ reduced_bt = (out_sum / out_cnt).view(T, n_bins_b)
356
+ elif reduce == "max":
357
+ reduced_bt = _segment_max(vals_flat, idx_flat, out_size).view(T, n_bins_b)
358
+ else:
359
+ raise ValueError("reduce must be 'mean' or 'max'.")
360
+
361
+ means_list.append(reduced_bt)
362
+ groups_list.append(parents_b)
363
+
364
+ if not padded:
365
+ return means_list, groups_list
366
+
367
+ # Padded output (B, T, max_bins) [+ mask]
368
+ mean_pad = torch.full((B, T, max_bins), fill_value, dtype=data.dtype, device=data.device)
369
+ groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
370
+ mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
371
+ for b, (m_b, g_b) in enumerate(zip(means_list, groups_list)):
372
+ nb = g_b.numel()
373
+ mean_pad[b, :, :nb] = m_b
374
+ groups_pad[b, :nb] = g_b
375
+ mask[b, :nb] = True
376
+
377
+ # Reshape back to [B, (*extra dims), max_bins] si besoin
378
+ if len(orig) > 1:
379
+ extra = orig[1:]
380
+ mean_pad = mean_pad.view(B, *extra, max_bins)
381
+ else:
382
+ mean_pad = mean_pad.view(B, max_bins)
383
+
384
+ return mean_pad, groups_pad, mask
291
385
 
292
- # --- construire indices de bins par ligne ---
293
- if cell_ids.ndim == 1:
294
- idx = inverse.expand(R, -1) # [R, N], dans [0..G-1]
295
386
  else:
296
- # cell_ids est [B_ids, N]; on doit « étirer » chaque ligne de mapping
297
- # pour couvrir les R lignes de data_flat en respectant la 1ère dim (B_data)
298
- B_data = original_shape[0]
299
- T = R // B_data # répétitions par batch-row
300
- idx = inverse.repeat_interleave(T, dim=0) # [B_ids*T, N] == [R, N]
301
-
302
- # --- scatter add (somme et compte) par ligne ---
303
- device = data.device
304
- row_offsets = torch.arange(R, device=device).unsqueeze(1) * G
305
- idx_offset = idx.to(torch.long) + row_offsets # [R,N]
306
- idx_offset_flat = idx_offset.reshape(-1)
307
- vals_flat = data_flat.reshape(-1)
308
-
309
- out_sum = torch.zeros(R * G, dtype=data.dtype, device=device)
310
- out_sum.scatter_add_(0, idx_offset_flat, vals_flat)
311
-
312
- ones = torch.ones_like(vals_flat, dtype=data.dtype, device=device)
313
- out_cnt = torch.zeros(R * G, dtype=data.dtype, device=device)
314
- out_cnt.scatter_add_(0, idx_offset_flat, ones)
315
- out_cnt = torch.clamp(out_cnt, min=1)
316
-
317
- mean = (out_sum / out_cnt).view(R, G).view(*original_shape, G)
318
-
319
- # On retourne les VRAIS ids HEALPix parents à nside/2
320
- return mean, unique_groups
321
- '''
387
+ raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
388
+
322
389
  def average_by_cell_group(data, cell_ids):
323
390
  """
324
391
  data: tensor of shape [..., N, ...] (ex: [B, N, C])
@@ -346,7 +413,7 @@ class BkTorch(BackendBase.BackendBase):
346
413
  return S.numel()
347
414
 
348
415
  def bk_SparseTensor(self, indice, w, dense_shape=[]):
349
- return self.backend.sparse_coo_tensor(indice.T, w, dense_shape).to_sparse_csr().to(self.torch_device)
416
+ return self.backend.sparse_coo_tensor(indice, w, dense_shape).coalesce().to_sparse_csr().to(self.torch_device)
350
417
 
351
418
  def bk_stack(self, list, axis=0):
352
419
  return self.backend.stack(list, axis=axis).to(self.torch_device)