foscat 2025.5.0__py3-none-any.whl → 2025.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTensorflow.py +138 -14
- foscat/BkTorch.py +90 -57
- foscat/CNN.py +31 -30
- foscat/FoCUS.py +640 -917
- foscat/GCNN.py +48 -150
- foscat/Softmax.py +1 -0
- foscat/alm.py +2 -2
- foscat/heal_NN.py +432 -0
- foscat/scat_cov.py +139 -96
- foscat/scat_cov_map2D.py +2 -2
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/METADATA +1 -1
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/RECORD +15 -14
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/WHEEL +1 -1
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/top_level.txt +0 -0
foscat/BkTensorflow.py
CHANGED
|
@@ -68,10 +68,16 @@ class BkTensorflow(BackendBase.BackendBase):
|
|
|
68
68
|
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
|
|
69
69
|
sys.stdout.flush()
|
|
70
70
|
self.ngpu = len(logical_gpus)
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
71
|
+
if self.ngpu > 0:
|
|
72
|
+
gpuname = logical_gpus[self.gpupos % self.ngpu].name
|
|
73
|
+
self.gpulist = {}
|
|
74
|
+
for i in range(self.ngpu):
|
|
75
|
+
self.gpulist[i] = logical_gpus[i].name
|
|
76
|
+
else:
|
|
77
|
+
gpuname = "CPU:0"
|
|
78
|
+
self.gpulist = {}
|
|
79
|
+
self.gpulist[0] = gpuname
|
|
80
|
+
self.ngpu = 1
|
|
75
81
|
|
|
76
82
|
except RuntimeError as e:
|
|
77
83
|
# Memory growth must be set before GPUs have been initialized
|
|
@@ -122,18 +128,136 @@ class BkTensorflow(BackendBase.BackendBase):
|
|
|
122
128
|
|
|
123
129
|
return x_padded
|
|
124
130
|
|
|
125
|
-
def
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
131
|
+
def binned_mean(self, data, cell_ids):
|
|
132
|
+
"""
|
|
133
|
+
data: Tensor of shape [..., N] (float32 or float64)
|
|
134
|
+
cell_ids: Tensor of shape [N], int indices in [0, n_bins)
|
|
135
|
+
Returns: mean per bin, shape [..., n_bins]
|
|
136
|
+
"""
|
|
137
|
+
ishape = list(data.shape)
|
|
138
|
+
A = 1
|
|
139
|
+
for k in range(len(ishape) - 1):
|
|
140
|
+
A *= ishape[k]
|
|
141
|
+
N = tf.shape(data)[-1]
|
|
142
|
+
|
|
143
|
+
# Step 1: group indices
|
|
144
|
+
groups = tf.math.floordiv(cell_ids, 4) # [N]
|
|
145
|
+
unique_groups, I = tf.unique(groups) # I: [N]
|
|
146
|
+
n_bins = tf.shape(unique_groups)[0]
|
|
147
|
+
|
|
148
|
+
# Step 2: build I_tiled with batch + channel offsets
|
|
149
|
+
I_tiled = tf.tile(I[None, :], [A, 1]) # shape [, N]
|
|
150
|
+
|
|
151
|
+
# Offset index to flatten across [A, n_bins]
|
|
152
|
+
batch_channel_offsets = tf.range(A)[:, None] * n_bins
|
|
153
|
+
I_offset = I_tiled + batch_channel_offsets # shape [A, N]]
|
|
154
|
+
|
|
155
|
+
# Step 3: flatten data to shape [A, N]
|
|
156
|
+
data_reshaped = tf.reshape(data, [A, N]) # shape [A, N]
|
|
157
|
+
|
|
158
|
+
# Flatten all for scatter_nd
|
|
159
|
+
indices = tf.reshape(I_offset, [-1]) # [A*N]
|
|
160
|
+
values = tf.reshape(data_reshaped, [-1]) # [A*N]
|
|
161
|
+
|
|
162
|
+
"""
|
|
163
|
+
# Prepare for scatter: indices → [A*N, 1]
|
|
164
|
+
scatter_indices = tf.expand_dims(indices, axis=1)
|
|
165
|
+
scatter_indices = tf.cast(scatter_indices, tf.int64)
|
|
166
|
+
"""
|
|
167
|
+
total_bins = A * n_bins
|
|
168
|
+
|
|
169
|
+
# Step 4: sum per bin
|
|
170
|
+
sum_per_bin = tf.math.unsorted_segment_sum(values, indices, total_bins)
|
|
171
|
+
sum_per_bin = tf.reshape(sum_per_bin, ishape[0:-1] + [n_bins]) # [A, n_bins]
|
|
172
|
+
|
|
173
|
+
# Step 5: count per bin (same indices)
|
|
174
|
+
counts = tf.math.unsorted_segment_sum(1.0 + 0 * values, indices, total_bins)
|
|
175
|
+
# counts = tf.math.bincount(indices, minlength=total_bins, maxlength=total_bins)
|
|
176
|
+
counts = tf.reshape(counts, ishape[0:-1] + [n_bins])
|
|
177
|
+
# counts = tf.maximum(counts, 1) # Avoid division by zero
|
|
178
|
+
# counts = tf.cast(counts, dtype=data.dtype)
|
|
179
|
+
|
|
180
|
+
# Step 6: mean
|
|
181
|
+
mean_per_bin = sum_per_bin / counts # [B, A, n_bins]
|
|
182
|
+
|
|
183
|
+
return mean_per_bin, unique_groups
|
|
184
|
+
|
|
185
|
+
def conv2d(self, x, w):
|
|
186
|
+
"""
|
|
187
|
+
Perform 2D convolution using TensorFlow.
|
|
188
|
+
|
|
189
|
+
Args:
|
|
190
|
+
x: Tensor of shape [..., Nx, Ny] – input
|
|
191
|
+
w: Tensor of shape [O_c, wx, wy] – conv weights
|
|
192
|
+
|
|
193
|
+
Returns:
|
|
194
|
+
Tensor of shape [..., O_c, Nx, Ny]
|
|
195
|
+
"""
|
|
196
|
+
# Extract shape
|
|
197
|
+
*leading_dims, Nx, Ny = x.shape
|
|
198
|
+
O_c, wx, wy = w.shape
|
|
199
|
+
|
|
200
|
+
# Flatten leading dims into a batch dimension
|
|
201
|
+
B = tf.reduce_prod(leading_dims) if leading_dims else 1
|
|
202
|
+
x = tf.reshape(x, [B, Nx, Ny, 1]) # TensorFlow format: [B, H, W, C_in=1]
|
|
203
|
+
|
|
204
|
+
# Reshape weights to [wx, wy, in_channels=1, out_channels]
|
|
205
|
+
w = tf.reshape(w, [O_c, wx, wy])
|
|
206
|
+
w = tf.transpose(w, perm=[1, 2, 0]) # [wx, wy, O_c]
|
|
207
|
+
w = tf.reshape(w, [wx, wy, 1, O_c]) # [wx, wy, C_in=1, C_out]
|
|
208
|
+
|
|
209
|
+
# Apply 'reflect' padding manually
|
|
210
|
+
pad_x = wx // 2
|
|
211
|
+
pad_y = wy // 2
|
|
212
|
+
x_padded = tf.pad(
|
|
213
|
+
x, [[0, 0], [pad_x, pad_x], [pad_y, pad_y], [0, 0]], mode="REFLECT"
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
# Perform convolution
|
|
217
|
+
y = tf.nn.conv2d(
|
|
218
|
+
x_padded, w, strides=[1, 1, 1, 1], padding="VALID"
|
|
219
|
+
) # [B, Nx, Ny, O_c]
|
|
220
|
+
|
|
221
|
+
# Transpose back to match original format: [..., O_c, Nx, Ny]
|
|
222
|
+
y = tf.transpose(y, [0, 3, 1, 2]) # [B, O_c, Nx, Ny]
|
|
223
|
+
y = tf.reshape(y, [*leading_dims, O_c, Nx, Ny])
|
|
224
|
+
|
|
225
|
+
return y
|
|
226
|
+
|
|
227
|
+
def conv1d(self, x, w):
|
|
228
|
+
"""
|
|
229
|
+
Perform 1D convolution using TensorFlow.
|
|
230
|
+
|
|
231
|
+
Args:
|
|
232
|
+
x: Tensor of shape [..., N] – input
|
|
233
|
+
w: Tensor of shape [k] – conv weights
|
|
234
|
+
|
|
235
|
+
Returns:
|
|
236
|
+
Tensor of shape [...,N]
|
|
237
|
+
"""
|
|
238
|
+
# Extract shapes
|
|
239
|
+
*leading_dims, N = x.shape
|
|
240
|
+
k = w.shape[0]
|
|
241
|
+
|
|
242
|
+
# Flatten leading dims into batch dimension
|
|
243
|
+
B = tf.reduce_prod(leading_dims) if leading_dims else 1
|
|
244
|
+
x = tf.reshape(x, [B, N, 1]) # TensorFlow 1D format: [B, L, C=1]
|
|
245
|
+
|
|
246
|
+
# Prepare weights: [k, in_channels=1, out_channels=O_c]
|
|
247
|
+
w = tf.reshape(w, [k, 1, 1])
|
|
248
|
+
|
|
249
|
+
# Apply 'reflect' padding
|
|
250
|
+
pad = k // 2
|
|
251
|
+
x_padded = tf.pad(x, [[0, 0], [pad, pad], [0, 0]], mode="REFLECT")
|
|
252
|
+
|
|
253
|
+
# Perform convolution
|
|
254
|
+
y = tf.nn.conv1d(x_padded, w, stride=1, padding="VALID") # [B, N, O_c]
|
|
130
255
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
tmp = self.backend.pad(x, paddings, "SYMMETRIC")
|
|
256
|
+
# Transpose to [B, O_c, N] and reshape back
|
|
257
|
+
y = tf.transpose(y, [0, 2, 1]) # [B, 1, N]
|
|
258
|
+
y = tf.reshape(y, [*leading_dims, N]) # [..., N]
|
|
135
259
|
|
|
136
|
-
return
|
|
260
|
+
return y
|
|
137
261
|
|
|
138
262
|
def bk_threshold(self, x, threshold, greater=True):
|
|
139
263
|
|
foscat/BkTorch.py
CHANGED
|
@@ -2,6 +2,7 @@ import sys
|
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
5
6
|
|
|
6
7
|
import foscat.BkBase as BackendBase
|
|
7
8
|
|
|
@@ -62,55 +63,69 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
62
63
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
63
64
|
)
|
|
64
65
|
|
|
66
|
+
import torch
|
|
67
|
+
|
|
65
68
|
def binned_mean(self, data, cell_ids):
|
|
66
69
|
"""
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
+
Compute the mean over groups of 4 nested HEALPix cells (nside → nside/2).
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
data (torch.Tensor): Tensor of shape [..., N], where N is the number of HEALPix cells.
|
|
74
|
+
cell_ids (torch.LongTensor): Tensor of shape [N], with cell indices (nested ordering).
|
|
75
|
+
|
|
76
|
+
Returns:
|
|
77
|
+
torch.Tensor: Tensor of shape [..., n_bins], with averaged values per group of 4 cells.
|
|
70
78
|
"""
|
|
71
|
-
|
|
79
|
+
if isinstance(data, np.ndarray):
|
|
80
|
+
data = torch.from_numpy(data).to(
|
|
81
|
+
dtype=torch.float32, device=self.torch_device
|
|
82
|
+
)
|
|
83
|
+
if isinstance(cell_ids, np.ndarray):
|
|
84
|
+
cell_ids = torch.from_numpy(cell_ids).to(
|
|
85
|
+
dtype=torch.long, device=self.torch_device
|
|
86
|
+
)
|
|
72
87
|
|
|
73
|
-
|
|
88
|
+
# Compute supercell ids by grouping 4 nested cells together
|
|
89
|
+
groups = cell_ids // 4
|
|
74
90
|
|
|
91
|
+
# Get unique group ids and inverse mapping
|
|
92
|
+
unique_groups, inverse_indices = torch.unique(groups, return_inverse=True)
|
|
75
93
|
n_bins = unique_groups.shape[0]
|
|
76
94
|
|
|
77
|
-
|
|
95
|
+
# Flatten all leading dimensions into a single batch dimension
|
|
96
|
+
original_shape = data.shape[:-1]
|
|
97
|
+
N = data.shape[-1]
|
|
98
|
+
data_flat = data.reshape(-1, N) # Shape: [B, N]
|
|
78
99
|
|
|
79
|
-
|
|
100
|
+
# Prepare to compute sums using scatter_add
|
|
101
|
+
B = data_flat.shape[0]
|
|
80
102
|
|
|
81
|
-
|
|
103
|
+
# Repeat inverse indices for each batch element
|
|
104
|
+
idx = inverse_indices.repeat(B, 1) # Shape: [B, N]
|
|
82
105
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
np.arange(A, dtype="int"), data.shape[1] * B
|
|
87
|
-
)
|
|
106
|
+
# Offset indices to simulate a per-batch scatter into [B * n_bins]
|
|
107
|
+
batch_offsets = torch.arange(B, device=data.device).unsqueeze(1) * n_bins
|
|
108
|
+
idx_offset = idx + batch_offsets # Shape: [B, N]
|
|
88
109
|
|
|
89
|
-
|
|
110
|
+
# Flatten everything for scatter
|
|
111
|
+
idx_offset_flat = idx_offset.flatten()
|
|
112
|
+
data_flat_flat = data_flat.flatten()
|
|
90
113
|
|
|
91
|
-
#
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
[B * n_bins], dtype=data.dtype, device=data.device
|
|
95
|
-
)
|
|
96
|
-
sum_per_bin = sum_per_bin.scatter_add(
|
|
97
|
-
0, I, self.bk_reshape(data, B * data.shape[1])
|
|
98
|
-
).reshape(B, n_bins)
|
|
114
|
+
# Accumulate sums per bin
|
|
115
|
+
out = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
|
|
116
|
+
out = out.scatter_add(0, idx_offset_flat, data_flat_flat)
|
|
99
117
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
)
|
|
105
|
-
sum_per_bin = sum_per_bin.scatter_add(
|
|
106
|
-
0, I, self.bk_reshape(data, B * data.shape[1] * A)
|
|
107
|
-
).reshape(
|
|
108
|
-
B, n_bins, A
|
|
109
|
-
) # [B, n_bins]
|
|
118
|
+
# Count number of elements per bin (to compute mean)
|
|
119
|
+
ones = torch.ones_like(data_flat_flat)
|
|
120
|
+
counts = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
|
|
121
|
+
counts = counts.scatter_add(0, idx_offset_flat, ones)
|
|
110
122
|
|
|
111
|
-
|
|
123
|
+
# Compute mean
|
|
124
|
+
mean = out / counts # Shape: [B * n_bins]
|
|
125
|
+
mean = mean.view(B, n_bins)
|
|
112
126
|
|
|
113
|
-
|
|
127
|
+
# Restore original leading dimensions
|
|
128
|
+
return mean.view(*original_shape, n_bins), unique_groups
|
|
114
129
|
|
|
115
130
|
def average_by_cell_group(data, cell_ids):
|
|
116
131
|
"""
|
|
@@ -134,11 +149,7 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
134
149
|
# -- BACKEND DEFINITION --
|
|
135
150
|
# ---------------------------------------------−---------
|
|
136
151
|
def bk_SparseTensor(self, indice, w, dense_shape=[]):
|
|
137
|
-
return (
|
|
138
|
-
self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
|
|
139
|
-
.to_sparse_csr()
|
|
140
|
-
.to(self.torch_device)
|
|
141
|
-
)
|
|
152
|
+
return self.backend.sparse_coo_tensor(indice.T, w, dense_shape).to_sparse_csr().to(self.torch_device)
|
|
142
153
|
|
|
143
154
|
def bk_stack(self, list, axis=0):
|
|
144
155
|
return self.backend.stack(list, axis=axis).to(self.torch_device)
|
|
@@ -146,20 +157,40 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
146
157
|
def bk_sparse_dense_matmul(self, smat, mat):
|
|
147
158
|
return smat.matmul(mat)
|
|
148
159
|
|
|
149
|
-
def conv2d(self, x, w
|
|
150
|
-
|
|
160
|
+
def conv2d(self, x, w):
|
|
161
|
+
"""
|
|
162
|
+
Perform 2D convolution using PyTorch format.
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
x: Tensor of shape [..., Nx, Ny] – input
|
|
166
|
+
w: Tensor of shape [O_c, wx, wy] – conv weights
|
|
167
|
+
|
|
168
|
+
Returns:
|
|
169
|
+
Tensor of shape [..., O_c, Nx, Ny]
|
|
170
|
+
"""
|
|
171
|
+
*leading_dims, Nx, Ny = x.shape # extract leading dims
|
|
172
|
+
O_c, wx, wy = w.shape
|
|
173
|
+
|
|
174
|
+
# Flatten leading dims into batch dimension
|
|
175
|
+
B = int(torch.prod(torch.tensor(leading_dims))) if leading_dims else 1
|
|
176
|
+
x = x.reshape(B, 1, Nx, Ny) # [B, 1, Nx, Ny]
|
|
177
|
+
|
|
178
|
+
# Reshape filters to match conv2d format [O_c, 1, wx, wy]
|
|
179
|
+
w = w[:, None, :, :] # [O_c, 1, wx, wy]
|
|
151
180
|
|
|
152
|
-
|
|
153
|
-
|
|
181
|
+
pad_x = wx // 2
|
|
182
|
+
pad_y = wy // 2
|
|
154
183
|
|
|
155
|
-
#
|
|
156
|
-
|
|
184
|
+
# Reflective padding to reduce edge artifacts
|
|
185
|
+
x_padded = F.pad(x, (pad_y, pad_y, pad_x, pad_x), mode="reflect")
|
|
157
186
|
|
|
158
|
-
#
|
|
159
|
-
|
|
187
|
+
# Apply convolution
|
|
188
|
+
y = F.conv2d(x_padded, w) # [B, O_c, Nx, Ny]
|
|
160
189
|
|
|
161
|
-
#
|
|
162
|
-
|
|
190
|
+
# Restore original leading dimensions
|
|
191
|
+
y = y.reshape(*leading_dims, O_c, Nx, Ny)
|
|
192
|
+
|
|
193
|
+
return y
|
|
163
194
|
|
|
164
195
|
def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
|
|
165
196
|
# to be written!!!
|
|
@@ -211,13 +242,13 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
211
242
|
xr = self.bk_real(x)
|
|
212
243
|
# xi = self.bk_imag(x)
|
|
213
244
|
|
|
214
|
-
r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
|
|
245
|
+
r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr + 1E-16)
|
|
215
246
|
# return r
|
|
216
247
|
# i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
|
|
217
248
|
|
|
218
249
|
return r
|
|
219
250
|
else:
|
|
220
|
-
return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x)
|
|
251
|
+
return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x + 1E-16)
|
|
221
252
|
|
|
222
253
|
def bk_square_comp(self, x):
|
|
223
254
|
if x.dtype == self.all_cbk_type:
|
|
@@ -356,9 +387,9 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
356
387
|
return self.backend.argmax(data)
|
|
357
388
|
|
|
358
389
|
def bk_reshape(self, data, shape):
|
|
359
|
-
if isinstance(data, np.ndarray):
|
|
360
|
-
|
|
361
|
-
return data.
|
|
390
|
+
#if isinstance(data, np.ndarray):
|
|
391
|
+
# return data.reshape(shape)
|
|
392
|
+
return data.reshape(shape)
|
|
362
393
|
|
|
363
394
|
def bk_repeat(self, data, nn, axis=0):
|
|
364
395
|
return self.backend.repeat_interleave(data, repeats=nn, dim=axis)
|
|
@@ -376,7 +407,7 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
376
407
|
return self.backend.unsqueeze(data, axis)
|
|
377
408
|
|
|
378
409
|
def bk_transpose(self, data, thelist):
|
|
379
|
-
return self.backend.transpose(data, thelist)
|
|
410
|
+
return self.backend.transpose(data, thelist[0], thelist[1])
|
|
380
411
|
|
|
381
412
|
def bk_concat(self, data, axis=None):
|
|
382
413
|
|
|
@@ -405,7 +436,9 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
405
436
|
return self.backend.zeros(shape, dtype=dtype).to(self.torch_device)
|
|
406
437
|
|
|
407
438
|
def bk_gather(self, data, idx, axis=0):
|
|
408
|
-
if axis ==
|
|
439
|
+
if axis == -1:
|
|
440
|
+
return data[...,idx]
|
|
441
|
+
elif axis == 0:
|
|
409
442
|
return data[idx]
|
|
410
443
|
elif axis == 1:
|
|
411
444
|
return data[:, idx]
|
|
@@ -413,7 +446,7 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
413
446
|
return data[:, :, idx]
|
|
414
447
|
elif axis == 3:
|
|
415
448
|
return data[:, :, :, idx]
|
|
416
|
-
return data[
|
|
449
|
+
return data[idx,...]
|
|
417
450
|
|
|
418
451
|
def bk_reverse(self, data, axis=0):
|
|
419
452
|
return self.backend.flip(data, dims=[axis])
|
foscat/CNN.py
CHANGED
|
@@ -9,13 +9,12 @@ class CNN:
|
|
|
9
9
|
|
|
10
10
|
def __init__(
|
|
11
11
|
self,
|
|
12
|
-
scat_operator=None,
|
|
13
12
|
nparam=1,
|
|
14
|
-
|
|
13
|
+
KERNELSZ=3,
|
|
14
|
+
NORIENT=4,
|
|
15
15
|
chanlist=[],
|
|
16
16
|
in_nside=1,
|
|
17
17
|
n_chan_in=1,
|
|
18
|
-
nbatch=1,
|
|
19
18
|
SEED=1234,
|
|
20
19
|
filename=None,
|
|
21
20
|
):
|
|
@@ -31,31 +30,30 @@ class CNN:
|
|
|
31
30
|
self.in_nside = outlist[4]
|
|
32
31
|
self.nbatch = outlist[1]
|
|
33
32
|
self.n_chan_in = outlist[8]
|
|
33
|
+
self.NORIENT = outlist[9]
|
|
34
34
|
self.x = self.scat_operator.backend.bk_cast(outlist[6])
|
|
35
35
|
self.out_nside = self.in_nside // (2**self.nscale)
|
|
36
36
|
else:
|
|
37
|
-
self.nscale =
|
|
38
|
-
self.nbatch = nbatch
|
|
37
|
+
self.nscale = len(chanlist)-1
|
|
39
38
|
self.npar = nparam
|
|
40
39
|
self.n_chan_in = n_chan_in
|
|
41
40
|
self.scat_operator = scat_operator
|
|
42
|
-
if
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
)
|
|
47
|
-
return None
|
|
41
|
+
if self.scat_operator is None:
|
|
42
|
+
self.scat_operator = sc.funct(
|
|
43
|
+
KERNELSZ=KERNELSZ,
|
|
44
|
+
NORIENT=NORIENT)
|
|
48
45
|
|
|
49
46
|
self.chanlist = chanlist
|
|
50
|
-
self.KERNELSZ = scat_operator.KERNELSZ
|
|
51
|
-
self.
|
|
47
|
+
self.KERNELSZ = self.scat_operator.KERNELSZ
|
|
48
|
+
self.NORIENT = self.scat_operator.NORIENT
|
|
49
|
+
self.all_type = self.scat_operator.all_type
|
|
52
50
|
self.in_nside = in_nside
|
|
53
51
|
self.out_nside = self.in_nside // (2**self.nscale)
|
|
54
|
-
|
|
52
|
+
self.backend = self.scat_operator.backend
|
|
55
53
|
np.random.seed(SEED)
|
|
56
|
-
self.x = scat_operator.backend.bk_cast(
|
|
57
|
-
np.random.
|
|
58
|
-
/ (self.KERNELSZ * self.KERNELSZ)
|
|
54
|
+
self.x = self.scat_operator.backend.bk_cast(
|
|
55
|
+
np.random.rand(self.get_number_of_weights())
|
|
56
|
+
/ (self.KERNELSZ * (self.KERNELSZ//2+1)*self.NORIENT)
|
|
59
57
|
)
|
|
60
58
|
|
|
61
59
|
def save(self, filename):
|
|
@@ -70,6 +68,7 @@ class CNN:
|
|
|
70
68
|
self.get_weights().numpy(),
|
|
71
69
|
self.all_type,
|
|
72
70
|
self.n_chan_in,
|
|
71
|
+
self.NORIENT,
|
|
73
72
|
]
|
|
74
73
|
|
|
75
74
|
myout = open("%s.pkl" % (filename), "wb")
|
|
@@ -82,8 +81,8 @@ class CNN:
|
|
|
82
81
|
totnchan = totnchan + self.chanlist[i] * self.chanlist[i + 1]
|
|
83
82
|
return (
|
|
84
83
|
self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]
|
|
85
|
-
+ totnchan * self.KERNELSZ * self.KERNELSZ
|
|
86
|
-
+ self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
|
|
84
|
+
+ totnchan * self.KERNELSZ * (self.KERNELSZ//2+1)
|
|
85
|
+
+ self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]
|
|
87
86
|
)
|
|
88
87
|
|
|
89
88
|
def set_weights(self, x):
|
|
@@ -95,30 +94,32 @@ class CNN:
|
|
|
95
94
|
def eval(self, im, indices=None, weights=None):
|
|
96
95
|
|
|
97
96
|
x = self.x
|
|
98
|
-
ww = self.
|
|
99
|
-
x[0 : self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]],
|
|
100
|
-
[self.KERNELSZ * self.KERNELSZ,
|
|
97
|
+
ww = self.backend.bk_reshape(
|
|
98
|
+
x[0 : self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]],
|
|
99
|
+
[self.n_chan_in, self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[0]],
|
|
101
100
|
)
|
|
102
|
-
nn = self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
|
|
101
|
+
nn = self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]
|
|
103
102
|
|
|
104
103
|
im = self.scat_operator.healpix_layer(im, ww)
|
|
105
|
-
im = self.
|
|
104
|
+
im = self.backend.bk_relu(im)
|
|
105
|
+
|
|
106
|
+
im = self.backend.bk_reduce_mean(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],im.shape[2]//4,4]),3)
|
|
106
107
|
|
|
107
108
|
for k in range(self.nscale):
|
|
108
109
|
ww = self.scat_operator.backend.bk_reshape(
|
|
109
110
|
x[
|
|
110
111
|
nn : nn
|
|
111
112
|
+ self.KERNELSZ
|
|
112
|
-
* self.KERNELSZ
|
|
113
|
+
* (self.KERNELSZ//2+1)
|
|
113
114
|
* self.chanlist[k]
|
|
114
115
|
* self.chanlist[k + 1]
|
|
115
116
|
],
|
|
116
|
-
[self.KERNELSZ * self.KERNELSZ,
|
|
117
|
+
[self.chanlist[k], self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[k + 1]],
|
|
117
118
|
)
|
|
118
119
|
nn = (
|
|
119
120
|
nn
|
|
120
121
|
+ self.KERNELSZ
|
|
121
|
-
* self.KERNELSZ
|
|
122
|
+
* (self.KERNELSZ//2)
|
|
122
123
|
* self.chanlist[k]
|
|
123
124
|
* self.chanlist[k + 1]
|
|
124
125
|
)
|
|
@@ -129,7 +130,7 @@ class CNN:
|
|
|
129
130
|
im, ww, indices=indices[k], weights=weights[k]
|
|
130
131
|
)
|
|
131
132
|
im = self.scat_operator.backend.bk_relu(im)
|
|
132
|
-
im = self.
|
|
133
|
+
im = self.backend.bk_reduce_mean(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],im.shape[2]//4,4]),3)
|
|
133
134
|
|
|
134
135
|
ww = self.scat_operator.backend.bk_reshape(
|
|
135
136
|
x[
|
|
@@ -141,11 +142,11 @@ class CNN:
|
|
|
141
142
|
|
|
142
143
|
im = self.scat_operator.backend.bk_matmul(
|
|
143
144
|
self.scat_operator.backend.bk_reshape(
|
|
144
|
-
im, [
|
|
145
|
+
im, [im.shape[0], im.shape[1] * im.shape[2]]
|
|
145
146
|
),
|
|
146
147
|
ww,
|
|
147
148
|
)
|
|
148
|
-
im = self.scat_operator.backend.bk_reshape(im, [self.npar])
|
|
149
|
+
#im = self.scat_operator.backend.bk_reshape(im, [self.npar])
|
|
149
150
|
im = self.scat_operator.backend.bk_relu(im)
|
|
150
151
|
|
|
151
152
|
return im
|