foscat 3.9.0.tar.gz → 2025.3.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {foscat-3.9.0/src/foscat.egg-info → foscat-2025.3.0}/PKG-INFO +3 -2
  2. {foscat-3.9.0 → foscat-2025.3.0}/pyproject.toml +1 -1
  3. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/BkTorch.py +68 -0
  4. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/FoCUS.py +157 -34
  5. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/Synthesis.py +1 -1
  6. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/scat_cov.py +393 -238
  7. {foscat-3.9.0 → foscat-2025.3.0/src/foscat.egg-info}/PKG-INFO +3 -2
  8. {foscat-3.9.0 → foscat-2025.3.0}/LICENSE +0 -0
  9. {foscat-3.9.0 → foscat-2025.3.0}/README.md +0 -0
  10. {foscat-3.9.0 → foscat-2025.3.0}/setup.cfg +0 -0
  11. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/BkBase.py +0 -0
  12. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/BkNumpy.py +0 -0
  13. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/BkTensorflow.py +0 -0
  14. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/CNN.py +0 -0
  15. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/CircSpline.py +0 -0
  16. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/GCNN.py +0 -0
  17. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/Softmax.py +0 -0
  18. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/Spline1D.py +0 -0
  19. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/__init__.py +0 -0
  20. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/alm.py +0 -0
  21. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/backend.py +0 -0
  22. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/backend_tens.py +0 -0
  23. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/loss_backend_tens.py +0 -0
  24. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/loss_backend_torch.py +0 -0
  25. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/scat.py +0 -0
  26. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/scat1D.py +0 -0
  27. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/scat2D.py +0 -0
  28. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/scat_cov1D.py +0 -0
  29. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/scat_cov2D.py +0 -0
  30. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/scat_cov_map.py +0 -0
  31. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat/scat_cov_map2D.py +0 -0
  32. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat.egg-info/SOURCES.txt +0 -0
  33. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat.egg-info/dependency_links.txt +0 -0
  34. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat.egg-info/requires.txt +0 -0
  35. {foscat-3.9.0 → foscat-2025.3.0}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: foscat
- Version: 3.9.0
+ Version: 2025.3.0
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -25,6 +25,7 @@ Requires-Dist: matplotlib
  Requires-Dist: numpy
  Requires-Dist: healpy
  Requires-Dist: spherical
+ Dynamic: license-file

  # foscat

@@ -1,6 +1,6 @@
  [project]
  name = "foscat"
- version = "3.9.0"
+ version = "2025.03.0"
  description = "Generate synthetic Healpix or 2D data using Cross Scattering Transform"
  readme = "README.md"
  license = { text = "BSD-3-Clause" }
@@ -62,6 +62,74 @@ class BkTorch(BackendBase.BackendBase):
              torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
          )

+     def binned_mean(self, data, cell_ids):
+         """
+         data: Tensor of shape [B, N, A]
+         cell_ids: Tensor of shape [N], integer indices in [0, n_bins)
+         Returns: mean per bin, shape [B, n_bins, A]
+         """
+         groups = cell_ids // 4  # [N]
+
+         unique_groups, I = np.unique(groups, return_inverse=True)
+
+         n_bins = unique_groups.shape[0]
+
+         B = data.shape[0]
+
+         counts = torch.bincount(torch.tensor(I).to(data.device))[None, :]
+
+         I = np.tile(I, B) + np.tile(n_bins * np.arange(B, dtype="int"), data.shape[1])
+
+         if len(data.shape) == 3:
+             A = data.shape[2]
+             I = np.repeat(I, A) * A + np.repeat(
+                 np.arange(A, dtype="int"), data.shape[1] * B
+             )
+
+         I = torch.tensor(I).to(data.device)
+
+         # Count per bin
+         if len(data.shape) == 2:
+             sum_per_bin = torch.zeros(
+                 [B * n_bins], dtype=data.dtype, device=data.device
+             )
+             sum_per_bin = sum_per_bin.scatter_add(
+                 0, I, self.bk_reshape(data, B * data.shape[1])
+             ).reshape(B, n_bins)
+
+             mean_per_bin = sum_per_bin / counts  # [B, n_bins]
+         else:
+             sum_per_bin = torch.zeros(
+                 [B * n_bins * A], dtype=data.dtype, device=data.device
+             )
+             sum_per_bin = sum_per_bin.scatter_add(
+                 0, I, self.bk_reshape(data, B * data.shape[1] * A)
+             ).reshape(
+                 B, n_bins, A
+             )  # [B, n_bins, A]
+
+             mean_per_bin = sum_per_bin / counts[:, :, None]  # [B, n_bins, A]
+
+         return mean_per_bin, unique_groups
+
+     def average_by_cell_group(data, cell_ids):
+         """
+         data: tensor of shape [..., N, ...] (ex: [B, N, C])
+         cell_ids: tensor of shape [N]
+         Returns: mean_data of shape [..., G, ...] where G = number of unique cell_ids//4
+         """
+         original_shape = data.shape
+         leading = data.shape[:-2]  # all dims before N
+         N = data.shape[-2]
+         trailing = data.shape[-1:]  # all dims after N
+
+         groups = (cell_ids // 4).long()  # [N]
+         unique_groups, group_indices, counts = torch.unique(
+             groups, return_inverse=True, return_counts=True
+         )
+
+         return torch.bincount(group_indices, weights=data) / counts, unique_groups
+
      # ---------------------------------------------−---------
      # -- BACKEND DEFINITION --
      # ---------------------------------------------−---------
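The new binned_mean reduces a partial nested-ordering HEALPix patch to its parent cells (cell_ids // 4) by flattening every (batch, bin) pair into one index and accumulating with scatter_add. Below is a minimal standalone sketch of that grouping idea in plain PyTorch; it is not the BkTorch implementation itself, and the ids, shapes, and variable names are illustrative.

import numpy as np
import torch

# Toy partial HEALPix patch: B maps over N nested cells (names illustrative).
B, N = 2, 8
cell_ids = np.array([4, 5, 6, 7, 12, 13, 14, 15])  # nested ids, 4 children per parent
data = torch.arange(B * N, dtype=torch.float32).reshape(B, N)

# Group cells by parent id and build one flat index per (batch, bin) pair.
groups = cell_ids // 4
parents, inv = np.unique(groups, return_inverse=True)  # inv: [N] in [0, n_bins)
n_bins = parents.shape[0]
idx = torch.as_tensor(np.repeat(np.arange(B), N) * n_bins + np.tile(inv, B))

# Sum per (batch, bin) with scatter_add, then divide by the bin sizes.
sums = torch.zeros(B * n_bins).scatter_add(0, idx, data.reshape(-1))
counts = torch.as_tensor(np.bincount(inv), dtype=torch.float32)
means = sums.reshape(B, n_bins) / counts  # [B, n_bins], one mean per parent cell
print(parents, means)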
@@ -35,7 +35,7 @@ class FoCUS:
          mpi_rank=0,
      ):

-         self.__version__ = "3.9.0"
+         self.__version__ = "2025.03.0"
          # P00 coeff for normalization for scat_cov
          self.TMPFILE_VERSION = TMPFILE_VERSION
          self.P1_dic = None
@@ -50,6 +50,12 @@ class FoCUS:
          self.return_data = return_data
          self.silent = silent

+         self.kernel_smooth = {}
+         self.padding_smooth = {}
+         self.kernelR_conv = {}
+         self.kernelI_conv = {}
+         self.padding_conv = {}
+
          if not self.silent:
              print("================================================")
              print(" START FOSCAT CONFIGURATION")
@@ -68,10 +74,7 @@ class FoCUS:
              if not self.silent:
                  print("The directory %s is created")
          except:
-             if not self.silent:
-                 print(
-                     "Impossible to create the directory %s" % (self.TEMPLATE_PATH)
-                 )
+             print("Impossible to create the directory %s" % (self.TEMPLATE_PATH))
              return None

          self.number_of_loss = 0
@@ -81,10 +84,9 @@ class FoCUS:
          self.padding = padding

          if JmaxDelta != 0:
-             if not self.silent:
-                 print(
-                     "OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
-                 )
+             print(
+                 "OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
+             )
              return None

          self.OSTEP = JmaxDelta
@@ -742,14 +744,14 @@ class FoCUS:
          return rim

      # --------------------------------------------------------
-     def ud_grade_2(self, im, axis=0):
+     def ud_grade_2(self, im, axis=0, cell_ids=None, nside=None):

          if self.use_2D:
              ishape = list(im.shape)
              if len(ishape) < axis + 2:
                  if not self.silent:
                      print("Use of 2D scat with data that has less than 2D")
-                 return None
+                 return None, None

              npix = im.shape[axis]
              npiy = im.shape[axis + 1]
@@ -774,29 +776,40 @@ class FoCUS:

              if axis == 0:
                  if len(ishape) == 2:
-                     return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
+                     return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
                  else:
-                     return self.backend.bk_reshape(
-                         res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
+                     return (
+                         self.backend.bk_reshape(
+                             res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
+                         ),
+                         None,
                      )
              else:
                  if len(ishape) == axis + 2:
-                     return self.backend.bk_reshape(
-                         res, ishape[0:axis] + [npix // 2, npiy // 2]
+                     return (
+                         self.backend.bk_reshape(
+                             res, ishape[0:axis] + [npix // 2, npiy // 2]
+                         ),
+                         None,
                      )
                  else:
-                     return self.backend.bk_reshape(
-                         res,
-                         ishape[0:axis] + [npix // 2, npiy // 2] + ishape[axis + 2 :],
+                     return (
+                         self.backend.bk_reshape(
+                             res,
+                             ishape[0:axis]
+                             + [npix // 2, npiy // 2]
+                             + ishape[axis + 2 :],
+                         ),
+                         None,
                      )

-             return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
+             return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
          elif self.use_1D:
              ishape = list(im.shape)
              if len(ishape) < axis + 1:
                  if not self.silent:
                      print("Use of 1D scat with data that has less than 1D")
-                 return None
+                 return None, None

              npix = im.shape[axis]
              odata = 1
@@ -819,23 +832,33 @@ class FoCUS:

              if axis == 0:
                  if len(ishape) == 1:
-                     return self.backend.bk_reshape(res, [npix // 2])
+                     return self.backend.bk_reshape(res, [npix // 2]), None
                  else:
-                     return self.backend.bk_reshape(
-                         res, [npix // 2] + ishape[axis + 1 :]
+                     return (
+                         self.backend.bk_reshape(res, [npix // 2] + ishape[axis + 1 :]),
+                         None,
                      )
              else:
                  if len(ishape) == axis + 1:
-                     return self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2])
+                     return (
+                         self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2]),
+                         None,
+                     )
                  else:
-                     return self.backend.bk_reshape(
-                         res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
+                     return (
+                         self.backend.bk_reshape(
+                             res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
+                         ),
+                         None,
                      )

-             return self.backend.bk_reshape(res, [npix // 2])
+             return self.backend.bk_reshape(res, [npix // 2]), None

          else:
              shape = list(im.shape)
+             if cell_ids is not None:
+                 sim, new_cell_ids = self.backend.binned_mean(im, cell_ids)
+                 return sim, new_cell_ids

              lout = int(np.sqrt(shape[axis] // 12))
              if im.__class__ == np.zeros([0]).__class__:
@@ -854,8 +877,11 @@ class FoCUS:
              if len(shape) > axis:
                  oshape = oshape + shape[axis + 1 :]

-             return self.backend.bk_reduce_mean(
-                 self.backend.bk_reshape(im, oshape), axis=axis + 1
+             return (
+                 self.backend.bk_reduce_mean(
+                     self.backend.bk_reshape(im, oshape), axis=axis + 1
+                 ),
+                 None,
              )

      # --------------------------------------------------------
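For full maps the HEALPix branch of ud_grade_2 still averages each group of four nested child pixels; the method now also returns a second value so that partial-map calls can propagate their parent cell ids. A minimal NumPy sketch of that reduction, with illustrative names (plain NumPy, not the foscat backend):

import numpy as np

nside = 4                                     # illustrative input resolution
npix = 12 * nside**2
im = np.random.rand(npix)                     # toy full map in NESTED ordering

down = im.reshape(npix // 4, 4).mean(axis=1)  # one value per parent cell
assert down.shape[0] == 12 * (nside // 2) ** 2

# In 2025.3.0 the method returns a pair: on full maps the second element is
# None, on partial maps (cell_ids given) it carries the parent cell ids, e.g.
#     down, new_ids = op.ud_grade_2(im, cell_ids=ids, nside=nside)
# where `op` stands for an existing FoCUS instance (illustrative).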
@@ -2139,7 +2165,7 @@ class FoCUS:
          return self.backend.bk_reduce_sum(r)

      # ---------------------------------------------−---------
-     def convol(self, in_image, axis=0):
+     def convol(self, in_image, axis=0, cell_ids=None, nside=None):

          image = self.backend.bk_cast(in_image)

@@ -2304,6 +2330,61 @@ class FoCUS:
              return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])

          else:
+             ishape = list(image.shape)
+
+             if cell_ids is not None:
+                 if cell_ids.shape[0] not in self.padding_conv:
+                     import healpix_convolution as hc
+                     from xdggs.healpix import HealpixInfo
+
+                     res = self.backend.bk_zeros(
+                         ishape + [self.NORIENT], dtype=self.backend.all_cbk_type
+                     )
+
+                     grid_info = HealpixInfo(
+                         level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
+                     )
+
+                     for k in range(self.NORIENT):
+                         kernelR, kernelI = hc.kernels.wavelet_kernel(
+                             cell_ids, grid_info=grid_info, orientation=k, is_torch=True
+                         )
+                         self.kernelR_conv[(cell_ids.shape[0], k)] = kernelR.to(
+                             self.backend.all_bk_type
+                         ).to(image.device)
+                         self.kernelI_conv[(cell_ids.shape[0], k)] = kernelI.to(
+                             self.backend.all_bk_type
+                         ).to(image.device)
+                         self.padding_conv[(cell_ids.shape[0], k)] = hc.pad(
+                             cell_ids,
+                             grid_info=grid_info,
+                             ring=5 // 2,  # wavelet kernel_size=5 is hard coded
+                             mode="mean",
+                             constant_value=0,
+                         )
+
+                 for k in range(self.NORIENT):
+
+                     kernelR = self.kernelR_conv[(cell_ids.shape[0], k)]
+                     kernelI = self.kernelI_conv[(cell_ids.shape[0], k)]
+                     padding = self.padding_conv[(cell_ids.shape[0], k)]
+                     if len(ishape) == 2:
+                         for l in range(ishape[0]):
+                             padded_data = padding.apply(image[l], is_torch=True)
+                             res[l, :, k] = kernelR.matmul(
+                                 padded_data
+                             ) + 1j * kernelI.matmul(padded_data)
+                     else:
+                         for l in range(ishape[0]):
+                             for k2 in range(ishape[2]):
+                                 padded_data = padding.apply(
+                                     image[l, :, k2], is_torch=True
+                                 )
+                                 res[l, :, k2, k] = kernelR.matmul(
+                                     padded_data
+                                 ) + 1j * kernelI.matmul(padded_data)
+                 return res
+
              nside = int(np.sqrt(image.shape[axis] // 12))

              if self.Idx_Neighbours[nside] is None:
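For partial nested maps, convol now builds one sparse (real, imaginary) oriented wavelet kernel pair per orientation with healpix_convolution and xdggs, pads the local cell set with a mean-filled ring, and applies the kernels as sparse matrix products. A hedged sketch of that pattern, mirroring the calls in the hunk above; it assumes a healpix_convolution build that provides wavelet_kernel, and the cell selection and data below are purely illustrative:

import numpy as np
import torch
import healpix_convolution as hc
from xdggs.healpix import HealpixInfo

nside = 16                               # illustrative resolution
cell_ids = np.arange(0, 4 * nside**2)    # a toy subset of nested cell ids
data = torch.rand(1, cell_ids.shape[0])  # [batch, n_cells]

grid_info = HealpixInfo(level=int(np.log2(nside)), indexing_scheme="nested")

# Same calls as in convol() above: oriented wavelet kernel + mean padding.
kernelR, kernelI = hc.kernels.wavelet_kernel(
    cell_ids, grid_info=grid_info, orientation=0, is_torch=True
)
padding = hc.pad(cell_ids, grid_info=grid_info, ring=2, mode="mean", constant_value=0)

padded = padding.apply(data[0], is_torch=True)
coeff = kernelR.matmul(padded) + 1j * kernelI.matmul(padded)  # one orientation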
@@ -2320,7 +2401,6 @@ class FoCUS:
              l_ww_real = self.ww_Real[nside]
              l_ww_imag = self.ww_Imag[nside]

-             ishape = list(image.shape)
              odata = 1
              for k in range(axis + 1, len(ishape)):
                  odata = odata * ishape[k]
@@ -2474,7 +2554,7 @@ class FoCUS:
          return res

      # ---------------------------------------------−---------
-     def smooth(self, in_image, axis=0):
+     def smooth(self, in_image, axis=0, cell_ids=None, nside=None):
          image = self.backend.bk_cast(in_image)


@@ -2603,6 +2683,50 @@ class FoCUS:
              return self.backend.bk_reshape(res, in_image.shape)

          else:
+
+             ishape = list(image.shape)
+
+             if cell_ids is not None:
+                 if cell_ids.shape[0] not in self.padding_smooth:
+                     import healpix_convolution as hc
+                     from xdggs.healpix import HealpixInfo
+
+                     grid_info = HealpixInfo(
+                         level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
+                     )
+
+                     kernel = hc.kernels.wavelet_smooth_kernel(
+                         cell_ids, grid_info=grid_info, is_torch=True
+                     )
+
+                     self.kernel_smooth[cell_ids.shape[0]] = kernel.to(
+                         self.backend.all_bk_type
+                     ).to(image.device)
+
+                     self.padding_smooth[cell_ids.shape[0]] = hc.pad(
+                         cell_ids,
+                         grid_info=grid_info,
+                         ring=5 // 2,  # wavelet kernel_size=5 is hard coded
+                         mode="mean",
+                         constant_value=0,
+                     )
+
+                 kernel = self.kernel_smooth[cell_ids.shape[0]]
+                 padding = self.padding_smooth[cell_ids.shape[0]]
+
+                 res = self.backend.bk_zeros(ishape, dtype=self.backend.all_cbk_type)
+
+                 if len(ishape) == 2:
+                     for l in range(ishape[0]):
+                         padded_data = padding.apply(image[l], is_torch=True)
+                         res[l] = kernel.matmul(padded_data)
+                 else:
+                     for l in range(ishape[0]):
+                         for k2 in range(ishape[2]):
+                             padded_data = padding.apply(image[l, :, k2], is_torch=True)
+                             res[l, :, k2] = kernel.matmul(padded_data)
+                 return res
+
              nside = int(np.sqrt(image.shape[axis] // 12))

              if self.Idx_Neighbours[nside] is None:
  if self.Idx_Neighbours[nside] is None:
@@ -2618,7 +2742,6 @@ class FoCUS:
                  self.w_smooth[nside] = ws

              l_w_smooth = self.w_smooth[nside]
-             ishape = list(image.shape)

              odata = 1
              for k in range(axis + 1, len(ishape)):
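smooth follows the same partial-map path as convol but applies a single real smoothing kernel, and callers that pass no cell_ids keep the previous full-map behaviour. A hedged usage sketch, where op stands for an existing FoCUS instance and full_map, patch, patch_ids, nside are illustrative inputs:

smoothed_full = op.smooth(full_map)                                 # unchanged full-map path
smoothed_patch = op.smooth(patch, cell_ids=patch_ids, nside=nside)  # new partial-map path
filtered_patch = op.convol(patch, cell_ids=patch_ids, nside=nside)  # oriented wavelet coefficients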
@@ -240,7 +240,7 @@ class Synthesis:
          grd_mask = self.grd_mask

          if grd_mask is not None:
-             g_tot = self.operation.backend.to_numpy(g_tot*grd_mask)
+             g_tot = self.operation.backend.to_numpy(g_tot * grd_mask)
          else:
              g_tot = self.operation.backend.to_numpy(g_tot)