foscat 3.9.0__tar.gz → 2025.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {foscat-3.9.0/src/foscat.egg-info → foscat-2025.5.0}/PKG-INFO +3 -2
- {foscat-3.9.0 → foscat-2025.5.0}/pyproject.toml +1 -1
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/BkBase.py +9 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/BkTorch.py +68 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/FoCUS.py +157 -34
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/Synthesis.py +1 -1
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/scat_cov.py +417 -246
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/scat_cov_map.py +0 -2
- {foscat-3.9.0 → foscat-2025.5.0/src/foscat.egg-info}/PKG-INFO +3 -2
- {foscat-3.9.0 → foscat-2025.5.0}/LICENSE +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/README.md +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/setup.cfg +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/BkNumpy.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/BkTensorflow.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/CNN.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/CircSpline.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/GCNN.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/Softmax.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/Spline1D.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/__init__.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/alm.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/backend.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/backend_tens.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/loss_backend_tens.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/loss_backend_torch.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/scat.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/scat1D.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/scat2D.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/scat_cov1D.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/scat_cov2D.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat/scat_cov_map2D.py +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat.egg-info/SOURCES.txt +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat.egg-info/dependency_links.txt +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat.egg-info/requires.txt +0 -0
- {foscat-3.9.0 → foscat-2025.5.0}/src/foscat.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: foscat
|
|
3
|
-
Version:
|
|
3
|
+
Version: 2025.5.0
|
|
4
4
|
Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
|
|
5
5
|
Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
|
|
6
6
|
Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
|
|
@@ -25,6 +25,7 @@ Requires-Dist: matplotlib
|
|
|
25
25
|
Requires-Dist: numpy
|
|
26
26
|
Requires-Dist: healpy
|
|
27
27
|
Requires-Dist: spherical
|
|
28
|
+
Dynamic: license-file
|
|
28
29
|
|
|
29
30
|
# foscat
|
|
30
31
|
|
|
@@ -23,6 +23,15 @@ class BackendBase:
|
|
|
23
23
|
self._fft_3_orient = {}
|
|
24
24
|
self._fft_3_orient_C = {}
|
|
25
25
|
|
|
26
|
+
def to_dict(self):
|
|
27
|
+
return {
|
|
28
|
+
"name": self.BACKEND,
|
|
29
|
+
"mpi_rank": self.mpi_rank,
|
|
30
|
+
"all_type": self.all_type,
|
|
31
|
+
"gpupos": self.gpupos,
|
|
32
|
+
"silent": self.silent,
|
|
33
|
+
}
|
|
34
|
+
|
|
26
35
|
def iso_mean(self, x, use_2D=False):
|
|
27
36
|
shape = list(x.shape)
|
|
28
37
|
|
|
@@ -62,6 +62,74 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
62
62
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
63
63
|
)
|
|
64
64
|
|
|
65
|
+
def binned_mean(self, data, cell_ids):
|
|
66
|
+
"""
|
|
67
|
+
data: Tensor of shape [B, N, A]
|
|
68
|
+
I: Tensor of shape [N], integer indices in [0, n_bins)
|
|
69
|
+
Returns: mean per bin, shape [B, n_bins, A]
|
|
70
|
+
"""
|
|
71
|
+
groups = cell_ids // 4 # [N]
|
|
72
|
+
|
|
73
|
+
unique_groups, I = np.unique(groups, return_inverse=True)
|
|
74
|
+
|
|
75
|
+
n_bins = unique_groups.shape[0]
|
|
76
|
+
|
|
77
|
+
B = data.shape[0]
|
|
78
|
+
|
|
79
|
+
counts = torch.bincount(torch.tensor(I).to(data.device))[None, :]
|
|
80
|
+
|
|
81
|
+
I = np.tile(I, B) + np.tile(n_bins * np.arange(B, dtype="int"), data.shape[1])
|
|
82
|
+
|
|
83
|
+
if len(data.shape) == 3:
|
|
84
|
+
A = data.shape[2]
|
|
85
|
+
I = np.repeat(I, A) * A + np.repeat(
|
|
86
|
+
np.arange(A, dtype="int"), data.shape[1] * B
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
I = torch.tensor(I).to(data.device)
|
|
90
|
+
|
|
91
|
+
# Comptage par bin
|
|
92
|
+
if len(data.shape) == 2:
|
|
93
|
+
sum_per_bin = torch.zeros(
|
|
94
|
+
[B * n_bins], dtype=data.dtype, device=data.device
|
|
95
|
+
)
|
|
96
|
+
sum_per_bin = sum_per_bin.scatter_add(
|
|
97
|
+
0, I, self.bk_reshape(data, B * data.shape[1])
|
|
98
|
+
).reshape(B, n_bins)
|
|
99
|
+
|
|
100
|
+
mean_per_bin = sum_per_bin / counts # [B, n_bins, A]
|
|
101
|
+
else:
|
|
102
|
+
sum_per_bin = torch.zeros(
|
|
103
|
+
[B * n_bins * A], dtype=data.dtype, device=data.device
|
|
104
|
+
)
|
|
105
|
+
sum_per_bin = sum_per_bin.scatter_add(
|
|
106
|
+
0, I, self.bk_reshape(data, B * data.shape[1] * A)
|
|
107
|
+
).reshape(
|
|
108
|
+
B, n_bins, A
|
|
109
|
+
) # [B, n_bins]
|
|
110
|
+
|
|
111
|
+
mean_per_bin = sum_per_bin / counts[:, :, None] # [B, n_bins, A]
|
|
112
|
+
|
|
113
|
+
return mean_per_bin, unique_groups
|
|
114
|
+
|
|
115
|
+
def average_by_cell_group(data, cell_ids):
|
|
116
|
+
"""
|
|
117
|
+
data: tensor of shape [..., N, ...] (ex: [B, N, C])
|
|
118
|
+
cell_ids: tensor of shape [N]
|
|
119
|
+
Returns: mean_data of shape [..., G, ...] where G = number of unique cell_ids//4
|
|
120
|
+
"""
|
|
121
|
+
original_shape = data.shape
|
|
122
|
+
leading = data.shape[:-2] # all dims before N
|
|
123
|
+
N = data.shape[-2]
|
|
124
|
+
trailing = data.shape[-1:] # all dims after N
|
|
125
|
+
|
|
126
|
+
groups = (cell_ids // 4).long() # [N]
|
|
127
|
+
unique_groups, group_indices, counts = torch.unique(
|
|
128
|
+
groups, return_inverse=True, return_counts=True
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
return torch.bincount(group_indices, weights=data) / counts, unique_groups
|
|
132
|
+
|
|
65
133
|
# ---------------------------------------------−---------
|
|
66
134
|
# -- BACKEND DEFINITION --
|
|
67
135
|
# ---------------------------------------------−---------
|
|
@@ -35,7 +35,7 @@ class FoCUS:
|
|
|
35
35
|
mpi_rank=0,
|
|
36
36
|
):
|
|
37
37
|
|
|
38
|
-
self.__version__ = "
|
|
38
|
+
self.__version__ = "2025.05.0"
|
|
39
39
|
# P00 coeff for normalization for scat_cov
|
|
40
40
|
self.TMPFILE_VERSION = TMPFILE_VERSION
|
|
41
41
|
self.P1_dic = None
|
|
@@ -50,6 +50,12 @@ class FoCUS:
|
|
|
50
50
|
self.return_data = return_data
|
|
51
51
|
self.silent = silent
|
|
52
52
|
|
|
53
|
+
self.kernel_smooth = {}
|
|
54
|
+
self.padding_smooth = {}
|
|
55
|
+
self.kernelR_conv = {}
|
|
56
|
+
self.kernelI_conv = {}
|
|
57
|
+
self.padding_conv = {}
|
|
58
|
+
|
|
53
59
|
if not self.silent:
|
|
54
60
|
print("================================================")
|
|
55
61
|
print(" START FOSCAT CONFIGURATION")
|
|
@@ -68,10 +74,7 @@ class FoCUS:
|
|
|
68
74
|
if not self.silent:
|
|
69
75
|
print("The directory %s is created")
|
|
70
76
|
except:
|
|
71
|
-
|
|
72
|
-
print(
|
|
73
|
-
"Impossible to create the directory %s" % (self.TEMPLATE_PATH)
|
|
74
|
-
)
|
|
77
|
+
print("Impossible to create the directory %s" % (self.TEMPLATE_PATH))
|
|
75
78
|
return None
|
|
76
79
|
|
|
77
80
|
self.number_of_loss = 0
|
|
@@ -81,10 +84,9 @@ class FoCUS:
|
|
|
81
84
|
self.padding = padding
|
|
82
85
|
|
|
83
86
|
if JmaxDelta != 0:
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
)
|
|
87
|
+
print(
|
|
88
|
+
"OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
|
|
89
|
+
)
|
|
88
90
|
return None
|
|
89
91
|
|
|
90
92
|
self.OSTEP = JmaxDelta
|
|
@@ -742,14 +744,14 @@ class FoCUS:
|
|
|
742
744
|
return rim
|
|
743
745
|
|
|
744
746
|
# --------------------------------------------------------
|
|
745
|
-
def ud_grade_2(self, im, axis=0):
|
|
747
|
+
def ud_grade_2(self, im, axis=0, cell_ids=None, nside=None):
|
|
746
748
|
|
|
747
749
|
if self.use_2D:
|
|
748
750
|
ishape = list(im.shape)
|
|
749
751
|
if len(ishape) < axis + 2:
|
|
750
752
|
if not self.silent:
|
|
751
753
|
print("Use of 2D scat with data that has less than 2D")
|
|
752
|
-
return None
|
|
754
|
+
return None, None
|
|
753
755
|
|
|
754
756
|
npix = im.shape[axis]
|
|
755
757
|
npiy = im.shape[axis + 1]
|
|
@@ -774,29 +776,40 @@ class FoCUS:
|
|
|
774
776
|
|
|
775
777
|
if axis == 0:
|
|
776
778
|
if len(ishape) == 2:
|
|
777
|
-
return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
|
|
779
|
+
return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
|
|
778
780
|
else:
|
|
779
|
-
return
|
|
780
|
-
|
|
781
|
+
return (
|
|
782
|
+
self.backend.bk_reshape(
|
|
783
|
+
res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
|
|
784
|
+
),
|
|
785
|
+
None,
|
|
781
786
|
)
|
|
782
787
|
else:
|
|
783
788
|
if len(ishape) == axis + 2:
|
|
784
|
-
return
|
|
785
|
-
|
|
789
|
+
return (
|
|
790
|
+
self.backend.bk_reshape(
|
|
791
|
+
res, ishape[0:axis] + [npix // 2, npiy // 2]
|
|
792
|
+
),
|
|
793
|
+
None,
|
|
786
794
|
)
|
|
787
795
|
else:
|
|
788
|
-
return
|
|
789
|
-
|
|
790
|
-
|
|
796
|
+
return (
|
|
797
|
+
self.backend.bk_reshape(
|
|
798
|
+
res,
|
|
799
|
+
ishape[0:axis]
|
|
800
|
+
+ [npix // 2, npiy // 2]
|
|
801
|
+
+ ishape[axis + 2 :],
|
|
802
|
+
),
|
|
803
|
+
None,
|
|
791
804
|
)
|
|
792
805
|
|
|
793
|
-
return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
|
|
806
|
+
return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
|
|
794
807
|
elif self.use_1D:
|
|
795
808
|
ishape = list(im.shape)
|
|
796
809
|
if len(ishape) < axis + 1:
|
|
797
810
|
if not self.silent:
|
|
798
811
|
print("Use of 1D scat with data that has less than 1D")
|
|
799
|
-
return None
|
|
812
|
+
return None, None
|
|
800
813
|
|
|
801
814
|
npix = im.shape[axis]
|
|
802
815
|
odata = 1
|
|
@@ -819,23 +832,33 @@ class FoCUS:
|
|
|
819
832
|
|
|
820
833
|
if axis == 0:
|
|
821
834
|
if len(ishape) == 1:
|
|
822
|
-
return self.backend.bk_reshape(res, [npix // 2])
|
|
835
|
+
return self.backend.bk_reshape(res, [npix // 2]), None
|
|
823
836
|
else:
|
|
824
|
-
return
|
|
825
|
-
res, [npix // 2] + ishape[axis + 1 :]
|
|
837
|
+
return (
|
|
838
|
+
self.backend.bk_reshape(res, [npix // 2] + ishape[axis + 1 :]),
|
|
839
|
+
None,
|
|
826
840
|
)
|
|
827
841
|
else:
|
|
828
842
|
if len(ishape) == axis + 1:
|
|
829
|
-
return
|
|
843
|
+
return (
|
|
844
|
+
self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2]),
|
|
845
|
+
None,
|
|
846
|
+
)
|
|
830
847
|
else:
|
|
831
|
-
return
|
|
832
|
-
|
|
848
|
+
return (
|
|
849
|
+
self.backend.bk_reshape(
|
|
850
|
+
res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
|
|
851
|
+
),
|
|
852
|
+
None,
|
|
833
853
|
)
|
|
834
854
|
|
|
835
|
-
return self.backend.bk_reshape(res, [npix // 2])
|
|
855
|
+
return self.backend.bk_reshape(res, [npix // 2]), None
|
|
836
856
|
|
|
837
857
|
else:
|
|
838
858
|
shape = list(im.shape)
|
|
859
|
+
if cell_ids is not None:
|
|
860
|
+
sim, new_cell_ids = self.backend.binned_mean(im, cell_ids)
|
|
861
|
+
return sim, new_cell_ids
|
|
839
862
|
|
|
840
863
|
lout = int(np.sqrt(shape[axis] // 12))
|
|
841
864
|
if im.__class__ == np.zeros([0]).__class__:
|
|
@@ -854,8 +877,11 @@ class FoCUS:
|
|
|
854
877
|
if len(shape) > axis:
|
|
855
878
|
oshape = oshape + shape[axis + 1 :]
|
|
856
879
|
|
|
857
|
-
return
|
|
858
|
-
self.backend.
|
|
880
|
+
return (
|
|
881
|
+
self.backend.bk_reduce_mean(
|
|
882
|
+
self.backend.bk_reshape(im, oshape), axis=axis + 1
|
|
883
|
+
),
|
|
884
|
+
None,
|
|
859
885
|
)
|
|
860
886
|
|
|
861
887
|
# --------------------------------------------------------
|
|
@@ -2139,7 +2165,7 @@ class FoCUS:
|
|
|
2139
2165
|
return self.backend.bk_reduce_sum(r)
|
|
2140
2166
|
|
|
2141
2167
|
# ---------------------------------------------−---------
|
|
2142
|
-
def convol(self, in_image, axis=0):
|
|
2168
|
+
def convol(self, in_image, axis=0, cell_ids=None, nside=None):
|
|
2143
2169
|
|
|
2144
2170
|
image = self.backend.bk_cast(in_image)
|
|
2145
2171
|
|
|
@@ -2304,6 +2330,61 @@ class FoCUS:
|
|
|
2304
2330
|
return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
|
|
2305
2331
|
|
|
2306
2332
|
else:
|
|
2333
|
+
ishape = list(image.shape)
|
|
2334
|
+
|
|
2335
|
+
if cell_ids is not None:
|
|
2336
|
+
if cell_ids.shape[0] not in self.padding_conv:
|
|
2337
|
+
import healpix_convolution as hc
|
|
2338
|
+
from xdggs.healpix import HealpixInfo
|
|
2339
|
+
|
|
2340
|
+
res = self.backend.bk_zeros(
|
|
2341
|
+
ishape + [self.NORIENT], dtype=self.backend.all_cbk_type
|
|
2342
|
+
)
|
|
2343
|
+
|
|
2344
|
+
grid_info = HealpixInfo(
|
|
2345
|
+
level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
|
|
2346
|
+
)
|
|
2347
|
+
|
|
2348
|
+
for k in range(self.NORIENT):
|
|
2349
|
+
kernelR, kernelI = hc.kernels.wavelet_kernel(
|
|
2350
|
+
cell_ids, grid_info=grid_info, orientation=k, is_torch=True
|
|
2351
|
+
)
|
|
2352
|
+
self.kernelR_conv[(cell_ids.shape[0], k)] = kernelR.to(
|
|
2353
|
+
self.backend.all_bk_type
|
|
2354
|
+
).to(image.device)
|
|
2355
|
+
self.kernelI_conv[(cell_ids.shape[0], k)] = kernelI.to(
|
|
2356
|
+
self.backend.all_bk_type
|
|
2357
|
+
).to(image.device)
|
|
2358
|
+
self.padding_conv[(cell_ids.shape[0], k)] = hc.pad(
|
|
2359
|
+
cell_ids,
|
|
2360
|
+
grid_info=grid_info,
|
|
2361
|
+
ring=5 // 2, # wavelet kernel_size=5 is hard coded
|
|
2362
|
+
mode="mean",
|
|
2363
|
+
constant_value=0,
|
|
2364
|
+
)
|
|
2365
|
+
|
|
2366
|
+
for k in range(self.NORIENT):
|
|
2367
|
+
|
|
2368
|
+
kernelR = self.kernelR_conv[(cell_ids.shape[0], k)]
|
|
2369
|
+
kernelI = self.kernelI_conv[(cell_ids.shape[0], k)]
|
|
2370
|
+
padding = self.padding_conv[(cell_ids.shape[0], k)]
|
|
2371
|
+
if len(ishape) == 2:
|
|
2372
|
+
for l in range(ishape[0]):
|
|
2373
|
+
padded_data = padding.apply(image[l], is_torch=True)
|
|
2374
|
+
res[l, :, k] = kernelR.matmul(
|
|
2375
|
+
padded_data
|
|
2376
|
+
) + 1j * kernelI.matmul(padded_data)
|
|
2377
|
+
else:
|
|
2378
|
+
for l in range(ishape[0]):
|
|
2379
|
+
for k2 in range(ishape[2]):
|
|
2380
|
+
padded_data = padding.apply(
|
|
2381
|
+
image[l, :, k2], is_torch=True
|
|
2382
|
+
)
|
|
2383
|
+
res[l, :, k2, k] = kernelR.matmul(
|
|
2384
|
+
padded_data
|
|
2385
|
+
) + 1j * kernelI.matmul(padded_data)
|
|
2386
|
+
return res
|
|
2387
|
+
|
|
2307
2388
|
nside = int(np.sqrt(image.shape[axis] // 12))
|
|
2308
2389
|
|
|
2309
2390
|
if self.Idx_Neighbours[nside] is None:
|
|
@@ -2320,7 +2401,6 @@ class FoCUS:
|
|
|
2320
2401
|
l_ww_real = self.ww_Real[nside]
|
|
2321
2402
|
l_ww_imag = self.ww_Imag[nside]
|
|
2322
2403
|
|
|
2323
|
-
ishape = list(image.shape)
|
|
2324
2404
|
odata = 1
|
|
2325
2405
|
for k in range(axis + 1, len(ishape)):
|
|
2326
2406
|
odata = odata * ishape[k]
|
|
@@ -2474,7 +2554,7 @@ class FoCUS:
|
|
|
2474
2554
|
return res
|
|
2475
2555
|
|
|
2476
2556
|
# ---------------------------------------------−---------
|
|
2477
|
-
def smooth(self, in_image, axis=0):
|
|
2557
|
+
def smooth(self, in_image, axis=0, cell_ids=None, nside=None):
|
|
2478
2558
|
|
|
2479
2559
|
image = self.backend.bk_cast(in_image)
|
|
2480
2560
|
|
|
@@ -2603,6 +2683,50 @@ class FoCUS:
|
|
|
2603
2683
|
return self.backend.bk_reshape(res, in_image.shape)
|
|
2604
2684
|
|
|
2605
2685
|
else:
|
|
2686
|
+
|
|
2687
|
+
ishape = list(image.shape)
|
|
2688
|
+
|
|
2689
|
+
if cell_ids is not None:
|
|
2690
|
+
if cell_ids.shape[0] not in self.padding_smooth:
|
|
2691
|
+
import healpix_convolution as hc
|
|
2692
|
+
from xdggs.healpix import HealpixInfo
|
|
2693
|
+
|
|
2694
|
+
grid_info = HealpixInfo(
|
|
2695
|
+
level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
|
|
2696
|
+
)
|
|
2697
|
+
|
|
2698
|
+
kernel = hc.kernels.wavelet_smooth_kernel(
|
|
2699
|
+
cell_ids, grid_info=grid_info, is_torch=True
|
|
2700
|
+
)
|
|
2701
|
+
|
|
2702
|
+
self.kernel_smooth[cell_ids.shape[0]] = kernel.to(
|
|
2703
|
+
self.backend.all_bk_type
|
|
2704
|
+
).to(image.device)
|
|
2705
|
+
|
|
2706
|
+
self.padding_smooth[cell_ids.shape[0]] = hc.pad(
|
|
2707
|
+
cell_ids,
|
|
2708
|
+
grid_info=grid_info,
|
|
2709
|
+
ring=5 // 2, # wavelet kernel_size=5 is hard coded
|
|
2710
|
+
mode="mean",
|
|
2711
|
+
constant_value=0,
|
|
2712
|
+
)
|
|
2713
|
+
|
|
2714
|
+
kernel = self.kernel_smooth[cell_ids.shape[0]]
|
|
2715
|
+
padding = self.padding_smooth[cell_ids.shape[0]]
|
|
2716
|
+
|
|
2717
|
+
res = self.backend.bk_zeros(ishape, dtype=self.backend.all_cbk_type)
|
|
2718
|
+
|
|
2719
|
+
if len(ishape) == 2:
|
|
2720
|
+
for l in range(ishape[0]):
|
|
2721
|
+
padded_data = padding.apply(image[l], is_torch=True)
|
|
2722
|
+
res[l] = kernel.matmul(padded_data)
|
|
2723
|
+
else:
|
|
2724
|
+
for l in range(ishape[0]):
|
|
2725
|
+
for k2 in range(ishape[2]):
|
|
2726
|
+
padded_data = padding.apply(image[l, :, k2], is_torch=True)
|
|
2727
|
+
res[l, :, k2] = kernel.matmul(padded_data)
|
|
2728
|
+
return res
|
|
2729
|
+
|
|
2606
2730
|
nside = int(np.sqrt(image.shape[axis] // 12))
|
|
2607
2731
|
|
|
2608
2732
|
if self.Idx_Neighbours[nside] is None:
|
|
@@ -2618,7 +2742,6 @@ class FoCUS:
|
|
|
2618
2742
|
self.w_smooth[nside] = ws
|
|
2619
2743
|
|
|
2620
2744
|
l_w_smooth = self.w_smooth[nside]
|
|
2621
|
-
ishape = list(image.shape)
|
|
2622
2745
|
|
|
2623
2746
|
odata = 1
|
|
2624
2747
|
for k in range(axis + 1, len(ishape)):
|
|
@@ -240,7 +240,7 @@ class Synthesis:
|
|
|
240
240
|
grd_mask = self.grd_mask
|
|
241
241
|
|
|
242
242
|
if grd_mask is not None:
|
|
243
|
-
g_tot = self.operation.backend.to_numpy(g_tot*grd_mask)
|
|
243
|
+
g_tot = self.operation.backend.to_numpy(g_tot * grd_mask)
|
|
244
244
|
else:
|
|
245
245
|
g_tot = self.operation.backend.to_numpy(g_tot)
|
|
246
246
|
|