foscat 2025.8.3.tar.gz → 2025.9.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {foscat-2025.8.3/src/foscat.egg-info → foscat-2025.9.1}/PKG-INFO +1 -1
  2. {foscat-2025.8.3 → foscat-2025.9.1}/pyproject.toml +1 -1
  3. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/BkTorch.py +241 -49
  4. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/FoCUS.py +5 -3
  5. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/HOrientedConvol.py +446 -42
  6. foscat-2025.9.1/src/foscat/HealBili.py +305 -0
  7. foscat-2025.9.1/src/foscat/Plot.py +328 -0
  8. foscat-2025.9.1/src/foscat/UNET.py +477 -0
  9. foscat-2025.9.1/src/foscat/healpix_unet_torch.py +717 -0
  10. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/scat_cov.py +42 -30
  11. {foscat-2025.8.3 → foscat-2025.9.1/src/foscat.egg-info}/PKG-INFO +1 -1
  12. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat.egg-info/SOURCES.txt +3 -0
  13. foscat-2025.8.3/src/foscat/UNET.py +0 -200
  14. {foscat-2025.8.3 → foscat-2025.9.1}/LICENSE +0 -0
  15. {foscat-2025.8.3 → foscat-2025.9.1}/README.md +0 -0
  16. {foscat-2025.8.3 → foscat-2025.9.1}/setup.cfg +0 -0
  17. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/BkBase.py +0 -0
  18. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/BkNumpy.py +0 -0
  19. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/BkTensorflow.py +0 -0
  20. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/CNN.py +0 -0
  21. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/CircSpline.py +0 -0
  22. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/GCNN.py +0 -0
  23. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/HealSpline.py +0 -0
  24. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/Softmax.py +0 -0
  25. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/Spline1D.py +0 -0
  26. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/Synthesis.py +0 -0
  27. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/__init__.py +0 -0
  28. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/alm.py +0 -0
  29. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/backend.py +0 -0
  30. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/backend_tens.py +0 -0
  31. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/heal_NN.py +0 -0
  32. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/loss_backend_tens.py +0 -0
  33. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/loss_backend_torch.py +0 -0
  34. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/scat.py +0 -0
  35. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/scat1D.py +0 -0
  36. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/scat2D.py +0 -0
  37. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/scat_cov1D.py +0 -0
  38. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/scat_cov2D.py +0 -0
  39. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/scat_cov_map.py +0 -0
  40. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat/scat_cov_map2D.py +0 -0
  41. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat.egg-info/dependency_links.txt +0 -0
  42. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat.egg-info/requires.txt +0 -0
  43. {foscat-2025.8.3 → foscat-2025.9.1}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: foscat
-Version: 2025.8.3
+Version: 2025.9.1
 Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
 Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
 Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -1,6 +1,6 @@
 [project]
 name = "foscat"
-version = "2025.08.3"
+version = "2025.09.1"
 description = "Generate synthetic Healpix or 2D data using Cross Scattering Transform"
 readme = "README.md"
 license = { text = "BSD-3-Clause" }
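
Editor's note: the hunks above are a plain version bump in the packaging metadata. To pick up the new release explicitly, the usual pip pin applies (inferred from the metadata, not stated in the diff):

pip install foscat==2025.9.1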
@@ -63,70 +63,262 @@ class BkTorch(BackendBase.BackendBase):
             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         )
 
-    import torch
-
-    def binned_mean(self, data, cell_ids):
+    # ---------------------------------
+    # HEALPix binning utilities (nested)
+    # ---------------------------------
+    # Robust binned_mean that supports arbitrary subsets (N not divisible by 4)
+    # and batched cell_ids of shape [B, N]. It returns compact per-parent means
+    # even when some parents are missing (sparse coverage).
+
+    def binned_mean(self, data, cell_ids, *, padded: bool = False, fill_value: float = float("nan")):
+        """Average values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
+
+        Works with full-sky or sparse subsets (no need for N to be divisible by 4).
+
+        Parameters
+        ----------
+        data : torch.Tensor or np.ndarray
+            Shape ``[..., N]`` or ``[B, ..., N]``.
+        cell_ids : torch.LongTensor or np.ndarray
+            Shape ``[N]`` or ``[B, N]`` (nested indexing at the *child* resolution).
+        padded : bool, optional (default: False)
+            Only used when ``cell_ids`` is ``[B, N]``. If ``False``, returns Python
+            lists (ragged) of per-batch results. If ``True``, returns padded tensors
+            plus a boolean mask of valid bins.
+        fill_value : float, optional
+            Value used for padding when ``padded=True``.
+
+        Returns
+        -------
+        If ``cell_ids`` is ``[N]``:
+            mean : torch.Tensor, shape ``[..., n_bins]``
+            groups: torch.LongTensor, shape ``[n_bins]`` (sorted unique parents)
+
+        If ``cell_ids`` is ``[B, N]`` and ``padded=False``:
+            means_list : List[torch.Tensor] of length B, each shape ``[T, n_bins_b]``
+                where ``T = prod(data.shape[1:-1])`` (or 1 if none).
+            groups_list : List[torch.LongTensor] of length B, each shape ``[n_bins_b]``
+
+        If ``cell_ids`` is ``[B, N]`` and ``padded=True``:
+            mean_padded : torch.Tensor, shape ``[B, T, max_bins]`` (or ``[B, max_bins]`` if T==1)
+            groups_pad : torch.LongTensor, shape ``[B, max_bins]`` (parents, padded with -1)
+            mask : torch.BoolTensor, shape ``[B, max_bins]`` (True where valid)
         """
-        Compute the mean over groups of 4 nested HEALPix cells (nside → nside/2).
-
-        Args:
-            data (torch.Tensor): Tensor of shape [..., N], where N is the number of HEALPix cells.
-            cell_ids (torch.LongTensor): Tensor of shape [N], with cell indices (nested ordering).
+        import torch, numpy as np
 
-        Returns:
-            torch.Tensor: Tensor of shape [..., n_bins], with averaged values per group of 4 cells.
-        """
+        # ---- Tensorize & device/dtype plumbing ----
         if isinstance(data, np.ndarray):
-            data = torch.from_numpy(data).to(
-                dtype=torch.float32, device=self.torch_device
-            )
+            data = torch.from_numpy(data).to(dtype=torch.float32, device=getattr(self, 'torch_device', 'cpu'))
         if isinstance(cell_ids, np.ndarray):
-            cell_ids = torch.from_numpy(cell_ids).to(
-                dtype=torch.long, device=self.torch_device
-            )
-
-        # Compute supercell ids by grouping 4 nested cells together
-        groups = cell_ids // 4
+            cell_ids = torch.from_numpy(cell_ids).to(dtype=torch.long, device=data.device)
 
-        # Get unique group ids and inverse mapping
-        unique_groups, inverse_indices = torch.unique(groups, return_inverse=True)
-        n_bins = unique_groups.shape[0]
+        data = data.to(device=getattr(self, 'torch_device', data.device))
+        cell_ids = cell_ids.to(device=data.device, dtype=torch.long)
 
-        # Flatten all leading dimensions into a single batch dimension
-        original_shape = data.shape[:-1]
+        if data.ndim < 1:
+            raise ValueError("`data` must have at least 1 dimension (last is N).")
         N = data.shape[-1]
-        data_flat = data.reshape(-1, N)  # Shape: [B, N]
 
-        # Prepare to compute sums using scatter_add
-        B = data_flat.shape[0]
-
-        # Repeat inverse indices for each batch element
-        idx = inverse_indices.repeat(B, 1)  # Shape: [B, N]
+        # Flatten leading dims (rows) for scatter convenience
+        orig = data.shape[:-1]
+        T = int(np.prod(orig[1:])) if len(orig) > 1 else 1  # repeats per batch row
+        if cell_ids.ndim == 1:
+            # Shared mapping for all rows
+            groups = (cell_ids // 4).to(torch.long)  # [N]
+            # Unique parent ids + inverse indices
+            parents, inv = torch.unique(groups, sorted=True, return_inverse=True)
+            n_bins = parents.numel()
+
+            R = int(np.prod(orig)) if len(orig) > 0 else 1
+            data_flat = data.reshape(R, N)  # [R, N]
+
+            # Row offsets -> independent bins per row
+            row_offsets = torch.arange(R, device=data.device).unsqueeze(1) * n_bins  # [R,1]
+            idx = inv.unsqueeze(0).expand(R, -1) + row_offsets  # [R,N]
+
+            vals_flat = data_flat.reshape(-1)
+            idx_flat = idx.reshape(-1)
+
+            out_sum = torch.zeros(R * n_bins, dtype=data.dtype, device=data.device)
+            out_cnt = torch.zeros_like(out_sum)
+            out_sum.scatter_add_(0, idx_flat, vals_flat)
+            out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
+            out_cnt.clamp_(min=1)
+
+            mean = (out_sum / out_cnt).view(*orig, n_bins)
+            return mean, parents
+
+        elif cell_ids.ndim == 2:
+            B = cell_ids.shape[0]
+            if data.shape[0] % B != 0:
+                raise ValueError(f"Leading dim of data ({data.shape[0]}) must be a multiple of cell_ids batch ({B}).")
+            R = int(np.prod(orig)) if len(orig) > 0 else 1
+            data_flat = data.reshape(R, N)  # [R, N]
+            B_data = data.shape[0]
+            T = R // B_data  # repeats per batch row (product of extra leading dims)
+
+            means_list, groups_list = [], []
+            max_bins = 0
+            # First pass: compute per-batch parents/inv and scatter means
+            for b in range(B):
+                groups_b = (cell_ids[b] // 4).to(torch.long)  # [N]
+                parents_b, inv_b = torch.unique(groups_b, sorted=True, return_inverse=True)
+                n_bins_b = parents_b.numel()
+                max_bins = max(max_bins, n_bins_b)
+
+                # rows for this batch in data_flat
+                start = b * T
+                stop = (b + 1) * T
+                rows = slice(start, stop)  # T rows
+
+                row_offsets = (torch.arange(T, device=data.device).unsqueeze(1) * n_bins_b)
+                idx = inv_b.unsqueeze(0).expand(T, -1) + row_offsets  # [T, N]
+
+                vals_flat = data_flat[rows].reshape(-1)
+                idx_flat = idx.reshape(-1)
+
+                out_sum = torch.zeros(T * n_bins_b, dtype=data.dtype, device=data.device)
+                out_cnt = torch.zeros_like(out_sum)
+                out_sum.scatter_add_(0, idx_flat, vals_flat)
+                out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
+                out_cnt.clamp_(min=1)
+                mean_bt = (out_sum / out_cnt).view(T, n_bins_b)  # [T, n_bins_b]
+
+                means_list.append(mean_bt)
+                groups_list.append(parents_b)
+
+            if not padded:
+                return means_list, groups_list
+
+            # Padded output
+            # mean_padded: [B, T, max_bins]; groups_pad: [B, max_bins]; mask: [B, max_bins]
+            mean_pad = torch.full((B, T, max_bins), fill_value, dtype=data.dtype, device=data.device)
+            groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
+            mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
+            for b, (m_b, g_b) in enumerate(zip(means_list, groups_list)):
+                nb = g_b.numel()
+                mean_pad[b, :, :nb] = m_b
+                groups_pad[b, :nb] = g_b
+                mask[b, :nb] = True
+
+            # Reshape back to [B, (*extra leading dims), max_bins] if needed
+            if len(orig) > 1:
+                extra = orig[1:]  # e.g., (D1, D2, ...)
+                mean_pad = mean_pad.view(B, *extra, max_bins)
+            else:
+                mean_pad = mean_pad.view(B, max_bins)
 
-        # Offset indices to simulate a per-batch scatter into [B * n_bins]
-        batch_offsets = torch.arange(B, device=data.device).unsqueeze(1) * n_bins
-        idx_offset = idx + batch_offsets  # Shape: [B, N]
+            return mean_pad, groups_pad, mask
 
-        # Flatten everything for scatter
-        idx_offset_flat = idx_offset.flatten()
-        data_flat_flat = data_flat.flatten()
+        else:
+            raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
+    '''
+    def binned_mean(self, data, cell_ids):
+        """
+        Mean over groups of 4 nested HEALPix pixels (nside -> nside/2);
+        works with an arbitrary subset of pixels.
+
+        Args
+        ----
+        data: torch.Tensor | np.ndarray, shape [..., N] or [B, ..., N]
+        cell_ids: torch.LongTensor | np.ndarray, shape [N] or [B, N] (nested)
+
+        Returns
+        -------
+        mean: torch.Tensor, shape [..., G] or [B, ..., G]
+        groups_out: torch.LongTensor, shape [G] (parent HEALPix ids at nside/2)
+        """
 
-        # Accumulate sums per bin
-        out = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
-        out = out.scatter_add(0, idx_offset_flat, data_flat_flat)
+        # --- to tensors on device ---
+        if isinstance(data, np.ndarray):
+            data = torch.from_numpy(data).to(dtype=torch.float32, device=self.torch_device)
+        if isinstance(cell_ids, np.ndarray):
+            cell_ids = torch.from_numpy(cell_ids).to(dtype=torch.long, device=self.torch_device)
+        data = data.to(self.torch_device)
+        cell_ids = cell_ids.to(self.torch_device, dtype=torch.long)
 
-        # Count number of elements per bin (to compute mean)
-        ones = torch.ones_like(data_flat_flat)
-        counts = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
-        counts = counts.scatter_add(0, idx_offset_flat, ones)
+        # --- shapes ---
+        if data.ndim < 1:
+            raise ValueError("`data` must have at least 1 dim; last is N.")
+        N = data.shape[-1]
+        if N % 4 != 0:
+            raise ValueError(f"N={N} must be divisible by 4 for nested groups of 4.")
+
+        # --- parent groups @ nside/2, accept [N] or [B,N] ---
+        if cell_ids.ndim == 1:
+            if cell_ids.shape[0] != N:
+                raise ValueError(f"cell_ids shape {tuple(cell_ids.shape)} incompatible with N={N}.")
+            groups_parent = (cell_ids // 4).long()  # [N]
+            # densify -> [0..G-1]
+            unique_groups, inverse = torch.unique(groups_parent, return_inverse=True)  # [G], [N]
+            # identical mapping for every row of data
+            B_ids = 1
+        elif cell_ids.ndim == 2:
+            B_ids, N_ids = cell_ids.shape
+            if N_ids != N:
+                raise ValueError(f"cell_ids last dim {N_ids} must equal N={N}.")
+            # batch compatibility check (data must start with B or a multiple of B)
+            leading = data.shape[:-1]
+            if len(leading) == 0:
+                raise ValueError("`data` must have a leading dim to match [B, N] cell_ids.")
+            B_data = leading[0]
+            if B_data % B_ids != 0:
+                raise ValueError(f"Leading batch of data ({B_data}) must be a multiple of cell_ids batch ({B_ids}).")
+
+            # Build a DENSE mapping per batch, but require the parent topology
+            # to be the same for all batches -> base it on batch 0
+            groups_parent0 = (cell_ids[0] // 4).long()  # [N]
+            unique_groups, inverse0 = torch.unique(groups_parent0, return_inverse=True)  # [G], [N]
+
+            # Check (optional but safe): every batch has the same parents (a different order could be OK)
+            # -> here we even require strict equality to avoid surprises;
+            # otherwise we could densify per batch and return a list of groups_out.
+            for b in range(1, B_ids):
+                if not torch.equal(groups_parent0, (cell_ids[b] // 4).long()):
+                    raise ValueError("All batches in cell_ids must share the same parent groups (order & content).")
+
+            # Build the inverse for all batches by replicating that of batch 0
+            inverse = inverse0.unsqueeze(0).expand(B_ids, -1)  # [B_ids, N]
+        else:
+            raise ValueError("`cell_ids` must be [N] or [B, N].")
 
-        # Compute mean
-        mean = out / counts  # Shape: [B * n_bins]
-        mean = mean.view(B, n_bins)
+        G = unique_groups.numel()  # number of parent groups (nside/2)
 
-        # Restore original leading dimensions
-        return mean.view(*original_shape, n_bins), unique_groups
+        # --- flatten data into rows ---
+        original_shape = data.shape[:-1]  # e.g. [B, D1, ...]
+        R = int(np.prod(original_shape)) if original_shape else 1
+        data_flat = data.reshape(R, N)  # [R, N]
 
+        # --- build bin indices per row ---
+        if cell_ids.ndim == 1:
+            idx = inverse.expand(R, -1)  # [R, N], in [0..G-1]
+        else:
+            # cell_ids is [B_ids, N]; each mapping row must be "stretched"
+            # to cover the R rows of data_flat, respecting the first dim (B_data)
+            B_data = original_shape[0]
+            T = R // B_data  # repetitions per batch row
+            idx = inverse.repeat_interleave(T, dim=0)  # [B_ids*T, N] == [R, N]
+
+        # --- scatter add (sum and count) per row ---
+        device = data.device
+        row_offsets = torch.arange(R, device=device).unsqueeze(1) * G
+        idx_offset = idx.to(torch.long) + row_offsets  # [R,N]
+        idx_offset_flat = idx_offset.reshape(-1)
+        vals_flat = data_flat.reshape(-1)
+
+        out_sum = torch.zeros(R * G, dtype=data.dtype, device=device)
+        out_sum.scatter_add_(0, idx_offset_flat, vals_flat)
+
+        ones = torch.ones_like(vals_flat, dtype=data.dtype, device=device)
+        out_cnt = torch.zeros(R * G, dtype=data.dtype, device=device)
+        out_cnt.scatter_add_(0, idx_offset_flat, ones)
+        out_cnt = torch.clamp(out_cnt, min=1)
+
+        mean = (out_sum / out_cnt).view(R, G).view(*original_shape, G)
+
+        # Return the TRUE parent HEALPix ids at nside/2
+        return mean, unique_groups
+    '''
     def average_by_cell_group(data, cell_ids):
         """
         data: tensor of shape [..., N, ...] (ex: [B, N, C])
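
Editor's sketch (not part of the release): a minimal usage example of the reworked binned_mean above, assuming BkTorch can be constructed with no arguments and that torch and numpy are installed; the pixel ids and values are hypothetical.

import numpy as np
from foscat.BkTorch import BkTorch  # import path as in this diff

bk = BkTorch()  # assumed default construction

# 12 child pixels at nside=4 (nested ids) covering a sparse subset of the sky
cell_ids = np.array([0, 1, 2, 3, 8, 9, 10, 40, 41, 42, 43, 44])
data = np.random.randn(2, cell_ids.size).astype(np.float32)  # [B=2, N=12]

# Average children into their nside=2 parents (parent id = child id // 4)
mean, parents = bk.binned_mean(data, cell_ids)
print(parents)     # tensor([ 0,  2, 10, 11]), i.e. unique(cell_ids // 4)
print(mean.shape)  # torch.Size([2, 4])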
@@ -36,7 +36,7 @@ class FoCUS:
         mpi_rank=0
     ):
 
-        self.__version__ = "2025.08.3"
+        self.__version__ = "2025.09.1"
         # P00 coeff for normalization for scat_cov
         self.TMPFILE_VERSION = TMPFILE_VERSION
         self.P1_dic = None
@@ -2380,9 +2380,11 @@ class FoCUS:
         res = v1 / vh
 
         oshape = [x.shape[0]] + [mask.shape[0]]
-        if len(x.shape)>1:
+        if len(x.shape) > 3:
            oshape = oshape + list(x.shape[1:-2])
-
+        else:
+            oshape = oshape + [1]
+
         if calc_var:
             if self.backend.bk_is_complex(vtmp):
                 res2 = self.backend.bk_sqrt(
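
Editor's sketch (not part of the release): the FoCUS hunk above only changes how the output shape list is assembled. The standalone helper below mirrors the two branches with hypothetical shapes; build_oshape is illustrative only.

def build_oshape(x_shape, n_mask):
    # mirrors: oshape = [x.shape[0]] + [mask.shape[0]], then the new branch
    oshape = [x_shape[0], n_mask]
    if len(x_shape) > 3:
        oshape += list(x_shape[1:-2])  # keep intermediate dims, drop the last two
    else:
        oshape += [1]  # new else branch: append an explicit singleton axis
    return oshape

print(build_oshape((8, 5, 64, 64), 3))  # [8, 3, 5]
print(build_oshape((8, 64, 64), 3))     # [8, 3, 1]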