foscat 2025.5.0.tar.gz → 2025.6.1.tar.gz

This diff shows the changes between two package versions as they were released to their public registry, and is provided for informational purposes only.
Files changed (37)
  1. {foscat-2025.5.0/src/foscat.egg-info → foscat-2025.6.1}/PKG-INFO +1 -1
  2. {foscat-2025.5.0 → foscat-2025.6.1}/pyproject.toml +1 -1
  3. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/BkTensorflow.py +138 -14
  4. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/BkTorch.py +90 -57
  5. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/CNN.py +31 -30
  6. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/FoCUS.py +640 -917
  7. foscat-2025.6.1/src/foscat/GCNN.py +137 -0
  8. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/Softmax.py +1 -0
  9. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/alm.py +2 -2
  10. foscat-2025.6.1/src/foscat/heal_NN.py +432 -0
  11. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/scat_cov.py +139 -96
  12. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/scat_cov_map2D.py +2 -2
  13. {foscat-2025.5.0 → foscat-2025.6.1/src/foscat.egg-info}/PKG-INFO +1 -1
  14. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat.egg-info/SOURCES.txt +1 -0
  15. foscat-2025.5.0/src/foscat/GCNN.py +0 -239
  16. {foscat-2025.5.0 → foscat-2025.6.1}/LICENSE +0 -0
  17. {foscat-2025.5.0 → foscat-2025.6.1}/README.md +0 -0
  18. {foscat-2025.5.0 → foscat-2025.6.1}/setup.cfg +0 -0
  19. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/BkBase.py +0 -0
  20. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/BkNumpy.py +0 -0
  21. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/CircSpline.py +0 -0
  22. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/Spline1D.py +0 -0
  23. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/Synthesis.py +0 -0
  24. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/__init__.py +0 -0
  25. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/backend.py +0 -0
  26. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/backend_tens.py +0 -0
  27. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/loss_backend_tens.py +0 -0
  28. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/loss_backend_torch.py +0 -0
  29. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/scat.py +0 -0
  30. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/scat1D.py +0 -0
  31. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/scat2D.py +0 -0
  32. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/scat_cov1D.py +0 -0
  33. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/scat_cov2D.py +0 -0
  34. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/scat_cov_map.py +0 -0
  35. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat.egg-info/dependency_links.txt +0 -0
  36. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat.egg-info/requires.txt +0 -0
  37. {foscat-2025.5.0 → foscat-2025.6.1}/src/foscat.egg-info/top_level.txt +0 -0
{foscat-2025.5.0/src/foscat.egg-info → foscat-2025.6.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: foscat
- Version: 2025.5.0
+ Version: 2025.6.1
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
{foscat-2025.5.0 → foscat-2025.6.1}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "foscat"
- version = "2025.05.0"
+ version = "2025.06.1"
  description = "Generate synthetic Healpix or 2D data using Cross Scattering Transform"
  readme = "README.md"
  license = { text = "BSD-3-Clause" }
{foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/BkTensorflow.py
@@ -68,10 +68,16 @@ class BkTensorflow(BackendBase.BackendBase):
              print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
              sys.stdout.flush()
              self.ngpu = len(logical_gpus)
-             gpuname = logical_gpus[self.gpupos % self.ngpu].name
-             self.gpulist = {}
-             for i in range(self.ngpu):
-                 self.gpulist[i] = logical_gpus[i].name
+             if self.ngpu > 0:
+                 gpuname = logical_gpus[self.gpupos % self.ngpu].name
+                 self.gpulist = {}
+                 for i in range(self.ngpu):
+                     self.gpulist[i] = logical_gpus[i].name
+             else:
+                 gpuname = "CPU:0"
+                 self.gpulist = {}
+                 self.gpulist[0] = gpuname
+                 self.ngpu = 1

          except RuntimeError as e:
              # Memory growth must be set before GPUs have been initialized
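Note: the new else-branch makes the TensorFlow backend fall back to a single pseudo-device named "CPU:0" when no logical GPU is found, instead of indexing into an empty device list. A minimal standalone sketch of the same pattern using the public tf.config API (variable names here are illustrative, not foscat's):

    # Minimal sketch of the GPU-or-CPU fallback (illustrative, not the foscat code itself)
    import tensorflow as tf

    logical_gpus = tf.config.list_logical_devices("GPU")
    if len(logical_gpus) > 0:
        ngpu = len(logical_gpus)
        gpulist = {i: g.name for i, g in enumerate(logical_gpus)}  # e.g. "/device:GPU:0"
    else:
        # No GPU visible: expose one fake entry so later device lookups still work
        ngpu = 1
        gpulist = {0: "CPU:0"}
    print(ngpu, gpulist)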
@@ -122,18 +128,136 @@ class BkTensorflow(BackendBase.BackendBase):

          return x_padded

-     def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
-         kx = w.shape[0]
-         ky = w.shape[1]
-         x_padded = self.periodic_pad(x, kx // 2, ky // 2)
-         return self.backend.nn.conv2d(x_padded, w, strides=strides, padding="VALID")
+     def binned_mean(self, data, cell_ids):
+         """
+         data: Tensor of shape [..., N] (float32 or float64)
+         cell_ids: Tensor of shape [N], int indices in [0, n_bins)
+         Returns: mean per bin, shape [..., n_bins]
+         """
+         ishape = list(data.shape)
+         A = 1
+         for k in range(len(ishape) - 1):
+             A *= ishape[k]
+         N = tf.shape(data)[-1]
+
+         # Step 1: group indices
+         groups = tf.math.floordiv(cell_ids, 4)  # [N]
+         unique_groups, I = tf.unique(groups)  # I: [N]
+         n_bins = tf.shape(unique_groups)[0]
+
+         # Step 2: build I_tiled with batch + channel offsets
+         I_tiled = tf.tile(I[None, :], [A, 1])  # shape [A, N]
+
+         # Offset index to flatten across [A, n_bins]
+         batch_channel_offsets = tf.range(A)[:, None] * n_bins
+         I_offset = I_tiled + batch_channel_offsets  # shape [A, N]
+
+         # Step 3: flatten data to shape [A, N]
+         data_reshaped = tf.reshape(data, [A, N])  # shape [A, N]
+
+         # Flatten all for scatter_nd
+         indices = tf.reshape(I_offset, [-1])  # [A*N]
+         values = tf.reshape(data_reshaped, [-1])  # [A*N]
+
+         """
+         # Prepare for scatter: indices → [A*N, 1]
+         scatter_indices = tf.expand_dims(indices, axis=1)
+         scatter_indices = tf.cast(scatter_indices, tf.int64)
+         """
+         total_bins = A * n_bins
+
+         # Step 4: sum per bin
+         sum_per_bin = tf.math.unsorted_segment_sum(values, indices, total_bins)
+         sum_per_bin = tf.reshape(sum_per_bin, ishape[0:-1] + [n_bins])  # [A, n_bins]
+
+         # Step 5: count per bin (same indices)
+         counts = tf.math.unsorted_segment_sum(1.0 + 0 * values, indices, total_bins)
+         # counts = tf.math.bincount(indices, minlength=total_bins, maxlength=total_bins)
+         counts = tf.reshape(counts, ishape[0:-1] + [n_bins])
+         # counts = tf.maximum(counts, 1)  # Avoid division by zero
+         # counts = tf.cast(counts, dtype=data.dtype)
+
+         # Step 6: mean
+         mean_per_bin = sum_per_bin / counts  # [B, A, n_bins]
+
+         return mean_per_bin, unique_groups
+
+     def conv2d(self, x, w):
+         """
+         Perform 2D convolution using TensorFlow.
+
+         Args:
+             x: Tensor of shape [..., Nx, Ny] – input
+             w: Tensor of shape [O_c, wx, wy] – conv weights
+
+         Returns:
+             Tensor of shape [..., O_c, Nx, Ny]
+         """
+         # Extract shape
+         *leading_dims, Nx, Ny = x.shape
+         O_c, wx, wy = w.shape
+
+         # Flatten leading dims into a batch dimension
+         B = tf.reduce_prod(leading_dims) if leading_dims else 1
+         x = tf.reshape(x, [B, Nx, Ny, 1])  # TensorFlow format: [B, H, W, C_in=1]
+
+         # Reshape weights to [wx, wy, in_channels=1, out_channels]
+         w = tf.reshape(w, [O_c, wx, wy])
+         w = tf.transpose(w, perm=[1, 2, 0])  # [wx, wy, O_c]
+         w = tf.reshape(w, [wx, wy, 1, O_c])  # [wx, wy, C_in=1, C_out]
+
+         # Apply 'reflect' padding manually
+         pad_x = wx // 2
+         pad_y = wy // 2
+         x_padded = tf.pad(
+             x, [[0, 0], [pad_x, pad_x], [pad_y, pad_y], [0, 0]], mode="REFLECT"
+         )
+
+         # Perform convolution
+         y = tf.nn.conv2d(
+             x_padded, w, strides=[1, 1, 1, 1], padding="VALID"
+         )  # [B, Nx, Ny, O_c]
+
+         # Transpose back to match original format: [..., O_c, Nx, Ny]
+         y = tf.transpose(y, [0, 3, 1, 2])  # [B, O_c, Nx, Ny]
+         y = tf.reshape(y, [*leading_dims, O_c, Nx, Ny])
+
+         return y
+
+     def conv1d(self, x, w):
+         """
+         Perform 1D convolution using TensorFlow.
+
+         Args:
+             x: Tensor of shape [..., N] – input
+             w: Tensor of shape [k] – conv weights
+
+         Returns:
+             Tensor of shape [..., N]
+         """
+         # Extract shapes
+         *leading_dims, N = x.shape
+         k = w.shape[0]
+
+         # Flatten leading dims into batch dimension
+         B = tf.reduce_prod(leading_dims) if leading_dims else 1
+         x = tf.reshape(x, [B, N, 1])  # TensorFlow 1D format: [B, L, C=1]
+
+         # Prepare weights: [k, in_channels=1, out_channels=O_c]
+         w = tf.reshape(w, [k, 1, 1])
+
+         # Apply 'reflect' padding
+         pad = k // 2
+         x_padded = tf.pad(x, [[0, 0], [pad, pad], [0, 0]], mode="REFLECT")
+
+         # Perform convolution
+         y = tf.nn.conv1d(x_padded, w, stride=1, padding="VALID")  # [B, N, O_c]

-     def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
-         kx = w.shape[0]
-         paddings = self.backend.constant([[0, 0], [kx // 2, kx // 2], [0, 0]])
-         tmp = self.backend.pad(x, paddings, "SYMMETRIC")
+         # Transpose to [B, O_c, N] and reshape back
+         y = tf.transpose(y, [0, 2, 1])  # [B, 1, N]
+         y = tf.reshape(y, [*leading_dims, N])  # [..., N]

-         return self.backend.nn.conv1d(tmp, w, stride=strides, padding="VALID")
+         return y

      def bk_threshold(self, x, threshold, greater=True):

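Note: the new TensorFlow binned_mean above averages HEALPix values over groups of four nested child cells: the cell ids are divided by 4, mapped to dense bin indices with tf.unique, and then summed per bin with tf.math.unsorted_segment_sum. A self-contained toy sketch of that averaging step (the data here is made up; this is not the foscat method itself):

    # Toy sketch of the segment-sum averaging used by the new binned_mean
    import numpy as np
    import tensorflow as tf

    cell_ids = tf.constant([0, 1, 2, 3, 8, 9, 10, 11], dtype=tf.int64)  # nested ids
    data = tf.constant(np.arange(8, dtype=np.float32))                  # shape [N]

    groups = cell_ids // 4                   # parent cell id at nside / 2
    unique_groups, seg = tf.unique(groups)   # seg[i] = dense bin index of sample i
    n_bins = tf.size(unique_groups)

    sums = tf.math.unsorted_segment_sum(data, seg, n_bins)
    counts = tf.math.unsorted_segment_sum(tf.ones_like(data), seg, n_bins)
    means = sums / counts                    # one mean per parent cell
    print(unique_groups.numpy(), means.numpy())  # [0 2] [1.5 5.5]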
{foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/BkTorch.py
@@ -2,6 +2,7 @@ import sys

  import numpy as np
  import torch
+ import torch.nn.functional as F

  import foscat.BkBase as BackendBase

@@ -62,55 +63,69 @@ class BkTorch(BackendBase.BackendBase):
              torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
          )

+     import torch
+
      def binned_mean(self, data, cell_ids):
          """
-         data: Tensor of shape [B, N, A]
-         I: Tensor of shape [N], integer indices in [0, n_bins)
-         Returns: mean per bin, shape [B, n_bins, A]
+         Compute the mean over groups of 4 nested HEALPix cells (nside → nside/2).
+
+         Args:
+             data (torch.Tensor): Tensor of shape [..., N], where N is the number of HEALPix cells.
+             cell_ids (torch.LongTensor): Tensor of shape [N], with cell indices (nested ordering).
+
+         Returns:
+             torch.Tensor: Tensor of shape [..., n_bins], with averaged values per group of 4 cells.
          """
-         groups = cell_ids // 4  # [N]
+         if isinstance(data, np.ndarray):
+             data = torch.from_numpy(data).to(
+                 dtype=torch.float32, device=self.torch_device
+             )
+         if isinstance(cell_ids, np.ndarray):
+             cell_ids = torch.from_numpy(cell_ids).to(
+                 dtype=torch.long, device=self.torch_device
+             )

-         unique_groups, I = np.unique(groups, return_inverse=True)
+         # Compute supercell ids by grouping 4 nested cells together
+         groups = cell_ids // 4

+         # Get unique group ids and inverse mapping
+         unique_groups, inverse_indices = torch.unique(groups, return_inverse=True)
          n_bins = unique_groups.shape[0]

-         B = data.shape[0]
+         # Flatten all leading dimensions into a single batch dimension
+         original_shape = data.shape[:-1]
+         N = data.shape[-1]
+         data_flat = data.reshape(-1, N)  # Shape: [B, N]

-         counts = torch.bincount(torch.tensor(I).to(data.device))[None, :]
+         # Prepare to compute sums using scatter_add
+         B = data_flat.shape[0]

-         I = np.tile(I, B) + np.tile(n_bins * np.arange(B, dtype="int"), data.shape[1])
+         # Repeat inverse indices for each batch element
+         idx = inverse_indices.repeat(B, 1)  # Shape: [B, N]

-         if len(data.shape) == 3:
-             A = data.shape[2]
-             I = np.repeat(I, A) * A + np.repeat(
-                 np.arange(A, dtype="int"), data.shape[1] * B
-             )
+         # Offset indices to simulate a per-batch scatter into [B * n_bins]
+         batch_offsets = torch.arange(B, device=data.device).unsqueeze(1) * n_bins
+         idx_offset = idx + batch_offsets  # Shape: [B, N]

-         I = torch.tensor(I).to(data.device)
+         # Flatten everything for scatter
+         idx_offset_flat = idx_offset.flatten()
+         data_flat_flat = data_flat.flatten()

-         # Count per bin
-         if len(data.shape) == 2:
-             sum_per_bin = torch.zeros(
-                 [B * n_bins], dtype=data.dtype, device=data.device
-             )
-             sum_per_bin = sum_per_bin.scatter_add(
-                 0, I, self.bk_reshape(data, B * data.shape[1])
-             ).reshape(B, n_bins)
+         # Accumulate sums per bin
+         out = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
+         out = out.scatter_add(0, idx_offset_flat, data_flat_flat)

-             mean_per_bin = sum_per_bin / counts  # [B, n_bins, A]
-         else:
-             sum_per_bin = torch.zeros(
-                 [B * n_bins * A], dtype=data.dtype, device=data.device
-             )
-             sum_per_bin = sum_per_bin.scatter_add(
-                 0, I, self.bk_reshape(data, B * data.shape[1] * A)
-             ).reshape(
-                 B, n_bins, A
-             )  # [B, n_bins]
+         # Count number of elements per bin (to compute mean)
+         ones = torch.ones_like(data_flat_flat)
+         counts = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
+         counts = counts.scatter_add(0, idx_offset_flat, ones)

-             mean_per_bin = sum_per_bin / counts[:, :, None]  # [B, n_bins, A]
+         # Compute mean
+         mean = out / counts  # Shape: [B * n_bins]
+         mean = mean.view(B, n_bins)

-         return mean_per_bin, unique_groups
+         return mean.view(*original_shape, n_bins), unique_groups

      def average_by_cell_group(data, cell_ids):
          """
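Note: in the rewritten PyTorch binned_mean above, every leading dimension is flattened into one batch axis and the inverse indices are shifted by batch * n_bins, so a single 1-D scatter_add accumulates all batches at once. A toy sketch of that offset trick with made-up shapes (not the foscat code):

    # Toy sketch of the batch-offset scatter_add trick
    import torch

    B, N, n_bins = 2, 8, 2
    data = torch.arange(B * N, dtype=torch.float32).reshape(B, N)
    seg = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])            # bin of each of the N samples

    idx = seg.repeat(B, 1)                                    # [B, N]
    idx = idx + torch.arange(B).unsqueeze(1) * n_bins         # shift each batch into its own range

    flat_idx = idx.flatten()
    flat_val = data.flatten()
    sums = torch.zeros(B * n_bins).scatter_add(0, flat_idx, flat_val)
    counts = torch.zeros(B * n_bins).scatter_add(0, flat_idx, torch.ones_like(flat_val))
    means = (sums / counts).view(B, n_bins)
    print(means)   # [[1.5, 5.5], [9.5, 13.5]]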
@@ -134,11 +149,7 @@ class BkTorch(BackendBase.BackendBase):
      # -- BACKEND DEFINITION --
      # ---------------------------------------------−---------
      def bk_SparseTensor(self, indice, w, dense_shape=[]):
-         return (
-             self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
-             .to_sparse_csr()
-             .to(self.torch_device)
-         )
+         return self.backend.sparse_coo_tensor(indice.T, w, dense_shape).to_sparse_csr().to(self.torch_device)

      def bk_stack(self, list, axis=0):
          return self.backend.stack(list, axis=axis).to(self.torch_device)
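Note: bk_SparseTensor keeps the same behaviour in a single expression: build a COO tensor, convert it to CSR, move it to the device. A minimal sketch with toy values (a sparse-CSR matmul against a dense matrix needs a reasonably recent PyTorch):

    # Minimal sketch of the COO -> CSR path used by bk_SparseTensor (toy values)
    import torch

    indices = torch.tensor([[0, 0], [1, 1]])                  # two nonzeros: (0, 0) and (1, 1)
    values = torch.tensor([2.0, 3.0])
    coo = torch.sparse_coo_tensor(indices.T, values, (2, 2))  # note the transpose, as in the diff
    csr = coo.to_sparse_csr()

    dense = torch.ones(2, 2)
    print(csr.matmul(dense))   # same result as diag(2, 3) @ ones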
@@ -146,20 +157,40 @@ class BkTorch(BackendBase.BackendBase):
      def bk_sparse_dense_matmul(self, smat, mat):
          return smat.matmul(mat)

-     def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
-         import torch.nn.functional as F
+     def conv2d(self, x, w):
+         """
+         Perform 2D convolution using PyTorch format.
+
+         Args:
+             x: Tensor of shape [..., Nx, Ny] – input
+             w: Tensor of shape [O_c, wx, wy] – conv weights
+
+         Returns:
+             Tensor of shape [..., O_c, Nx, Ny]
+         """
+         *leading_dims, Nx, Ny = x.shape  # extract leading dims
+         O_c, wx, wy = w.shape
+
+         # Flatten leading dims into batch dimension
+         B = int(torch.prod(torch.tensor(leading_dims))) if leading_dims else 1
+         x = x.reshape(B, 1, Nx, Ny)  # [B, 1, Nx, Ny]
+
+         # Reshape filters to match conv2d format [O_c, 1, wx, wy]
+         w = w[:, None, :, :]  # [O_c, 1, wx, wy]

-         lx = x.permute(0, 3, 1, 2)
-         wx = w.permute(3, 2, 0, 1)  # from (5, 5, 1, 4) to (4, 1, 5, 5)
+         pad_x = wx // 2
+         pad_y = wy // 2

-         # Compute the symmetric padding
-         kx, ky = w.shape[0], w.shape[1]
+         # Reflective padding to reduce edge artifacts
+         x_padded = F.pad(x, (pad_y, pad_y, pad_x, pad_x), mode="reflect")

-         # Apply the padding
-         x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode="circular")
+         # Apply convolution
+         y = F.conv2d(x_padded, w)  # [B, O_c, Nx, Ny]

-         # Apply the convolution
-         return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0, 2, 3, 1)
+         # Restore original leading dimensions
+         y = y.reshape(*leading_dims, O_c, Nx, Ny)
+
+         return y

      def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
          # to be written!!!
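Note: the new conv2d treats every leading axis as batch, adds a singleton channel axis, reflect-pads by half the kernel size, and convolves, so the spatial size is preserved and the filters appear as an extra axis in [..., O_c, Nx, Ny]. A toy sketch of the same shape handling (shapes are made up; this mirrors but is not the BkTorch method):

    # Toy sketch of reflect-pad + conv2d for [..., Nx, Ny] inputs
    import torch
    import torch.nn.functional as F

    x = torch.randn(3, 5, 16, 16)        # e.g. [batch, channel, Nx, Ny]
    w = torch.randn(4, 3, 3)             # [O_c, wx, wy]

    *lead, Nx, Ny = x.shape
    O_c, wx, wy = w.shape

    xb = x.reshape(-1, 1, Nx, Ny)                                    # [B, 1, Nx, Ny]
    xb = F.pad(xb, (wy // 2, wy // 2, wx // 2, wx // 2), mode="reflect")
    y = F.conv2d(xb, w[:, None, :, :])                               # [B, O_c, Nx, Ny]
    y = y.reshape(*lead, O_c, Nx, Ny)
    print(y.shape)                        # torch.Size([3, 5, 4, 16, 16])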
@@ -211,13 +242,13 @@ class BkTorch(BackendBase.BackendBase):
              xr = self.bk_real(x)
              # xi = self.bk_imag(x)

-             r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
+             r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr + 1E-16)
              # return r
              # i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)

              return r
          else:
-             return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x)
+             return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x + 1E-16)

      def bk_square_comp(self, x):
          if x.dtype == self.all_cbk_type:
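Note: the two added + 1E-16 terms keep the square root away from zero. Since sign(x) * x equals |x| and the derivative of sqrt diverges at 0, the backward pass produces NaN whenever an input is exactly zero without the epsilon. A toy autograd check:

    # Why the + 1E-16 matters: gradient of sign(x) * sqrt(sign(x) * x) at x == 0 (toy check)
    import torch

    x = torch.zeros(1, requires_grad=True)
    y = torch.sign(x) * torch.sqrt(torch.sign(x) * x)              # old form
    y.backward()
    print(x.grad)                                                  # tensor([nan])

    x2 = torch.zeros(1, requires_grad=True)
    y2 = torch.sign(x2) * torch.sqrt(torch.sign(x2) * x2 + 1e-16)  # new form
    y2.backward()
    print(x2.grad)                                                 # tensor([0.]) – finite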
@@ -356,9 +387,9 @@ class BkTorch(BackendBase.BackendBase):
          return self.backend.argmax(data)

      def bk_reshape(self, data, shape):
-         if isinstance(data, np.ndarray):
-             return data.reshape(shape)
-         return data.view(shape)
+         # if isinstance(data, np.ndarray):
+         #     return data.reshape(shape)
+         return data.reshape(shape)

      def bk_repeat(self, data, nn, axis=0):
          return self.backend.repeat_interleave(data, repeats=nn, dim=axis)
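Note: bk_reshape now always calls reshape. A .reshape method exists on both NumPy arrays and torch tensors, and unlike view it tolerates non-contiguous tensors by copying when needed, so the isinstance branch became unnecessary. A toy illustration:

    # reshape works where view fails on non-contiguous tensors (toy check)
    import torch

    t = torch.arange(6).reshape(2, 3).t()   # transpose -> non-contiguous
    print(t.is_contiguous())                # False
    print(t.reshape(6).shape)               # torch.Size([6]); reshape copies if it must
    # t.view(6) would raise a RuntimeError here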
@@ -376,7 +407,7 @@ class BkTorch(BackendBase.BackendBase):
          return self.backend.unsqueeze(data, axis)

      def bk_transpose(self, data, thelist):
-         return self.backend.transpose(data, thelist)
+         return self.backend.transpose(data, thelist[0], thelist[1])

      def bk_concat(self, data, axis=None):

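Note: torch.transpose swaps exactly two dimensions (dim0, dim1) rather than taking a full permutation the way tf.transpose does, which is why the call now unpacks the first two entries of thelist. A toy comparison with permute:

    # torch.transpose swaps two dims; a full permutation needs Tensor.permute (toy shapes)
    import torch

    x = torch.randn(2, 3, 4)
    print(torch.transpose(x, 0, 1).shape)   # torch.Size([3, 2, 4])
    print(x.permute(2, 0, 1).shape)         # torch.Size([4, 2, 3])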
@@ -405,7 +436,9 @@ class BkTorch(BackendBase.BackendBase):
          return self.backend.zeros(shape, dtype=dtype).to(self.torch_device)

      def bk_gather(self, data, idx, axis=0):
-         if axis == 0:
+         if axis == -1:
+             return data[..., idx]
+         elif axis == 0:
              return data[idx]
          elif axis == 1:
              return data[:, idx]
@@ -413,7 +446,7 @@ class BkTorch(BackendBase.BackendBase):
              return data[:, :, idx]
          elif axis == 3:
              return data[:, :, :, idx]
-         return data[:, :, :, :, idx]
+         return data[idx, ...]

      def bk_reverse(self, data, axis=0):
          return self.backend.flip(data, dims=[axis])
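Note: the added axis == -1 branch and the new fallback data[idx, ...] are plain advanced indexing; a toy check of the shapes they produce:

    # Advanced indexing along the last or first axis, as in the updated bk_gather (toy data)
    import torch

    data = torch.arange(24).reshape(2, 3, 4)
    idx = torch.tensor([3, 0])

    print(data[..., idx].shape)   # torch.Size([2, 3, 2])  -> axis == -1 branch
    print(data[idx, ...].shape)   # torch.Size([2, 3, 4])  -> fallback branch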
{foscat-2025.5.0 → foscat-2025.6.1}/src/foscat/CNN.py
@@ -9,13 +9,12 @@ class CNN:

      def __init__(
          self,
-         scat_operator=None,
          nparam=1,
-         nscale=1,
+         KERNELSZ=3,
+         NORIENT=4,
          chanlist=[],
          in_nside=1,
          n_chan_in=1,
-         nbatch=1,
          SEED=1234,
          filename=None,
      ):
@@ -31,31 +30,30 @@ class CNN:
              self.in_nside = outlist[4]
              self.nbatch = outlist[1]
              self.n_chan_in = outlist[8]
+             self.NORIENT = outlist[9]
              self.x = self.scat_operator.backend.bk_cast(outlist[6])
              self.out_nside = self.in_nside // (2**self.nscale)
          else:
-             self.nscale = nscale
-             self.nbatch = nbatch
+             self.nscale = len(chanlist) - 1
              self.npar = nparam
              self.n_chan_in = n_chan_in
              self.scat_operator = scat_operator
-             if len(chanlist) != nscale + 1:
-                 print(
-                     "len of chanlist (here %d) should of nscale+1 (here %d)"
-                     % (len(chanlist), nscale + 1)
-                 )
-                 return None
+             if self.scat_operator is None:
+                 self.scat_operator = sc.funct(
+                     KERNELSZ=KERNELSZ,
+                     NORIENT=NORIENT)

              self.chanlist = chanlist
-             self.KERNELSZ = scat_operator.KERNELSZ
-             self.all_type = scat_operator.all_type
+             self.KERNELSZ = self.scat_operator.KERNELSZ
+             self.NORIENT = self.scat_operator.NORIENT
+             self.all_type = self.scat_operator.all_type
              self.in_nside = in_nside
              self.out_nside = self.in_nside // (2**self.nscale)
-
+             self.backend = self.scat_operator.backend
              np.random.seed(SEED)
-             self.x = scat_operator.backend.bk_cast(
-                 np.random.randn(self.get_number_of_weights())
-                 / (self.KERNELSZ * self.KERNELSZ)
+             self.x = self.scat_operator.backend.bk_cast(
+                 np.random.rand(self.get_number_of_weights())
+                 / (self.KERNELSZ * (self.KERNELSZ // 2 + 1) * self.NORIENT)
              )

      def save(self, filename):
@@ -70,6 +68,7 @@ class CNN:
              self.get_weights().numpy(),
              self.all_type,
              self.n_chan_in,
+             self.NORIENT,
          ]

          myout = open("%s.pkl" % (filename), "wb")
@@ -82,8 +81,8 @@ class CNN:
              totnchan = totnchan + self.chanlist[i] * self.chanlist[i + 1]
          return (
              self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]
-             + totnchan * self.KERNELSZ * self.KERNELSZ
-             + self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
+             + totnchan * self.KERNELSZ * (self.KERNELSZ // 2 + 1)
+             + self.KERNELSZ * (self.KERNELSZ // 2 + 1) * self.n_chan_in * self.chanlist[0]
          )

      def set_weights(self, x):
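Note: get_number_of_weights now counts KERNELSZ * (KERNELSZ // 2 + 1) weights per filter instead of KERNELSZ squared. A small stand-alone helper that reproduces the new formula can be handy for sanity-checking layer sizes; the function name and example values below are mine (and the loop is assumed to run over range(nscale), as the surrounding code suggests), not part of the foscat API:

    # Hypothetical helper mirroring the new get_number_of_weights() formula (sanity checks only)
    def cnn_weight_count(npar, out_nside, chanlist, kernelsz, n_chan_in):
        nscale = len(chanlist) - 1
        ksz = kernelsz * (kernelsz // 2 + 1)          # reduced kernel footprint used in 2025.6.1
        totnchan = sum(chanlist[i] * chanlist[i + 1] for i in range(nscale))
        return (
            npar * 12 * out_nside**2 * chanlist[nscale]
            + totnchan * ksz
            + ksz * n_chan_in * chanlist[0]
        )

    print(cnn_weight_count(npar=8, out_nside=4, chanlist=[8, 16, 32], kernelsz=3, n_chan_in=1))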
@@ -95,30 +94,32 @@ class CNN:
      def eval(self, im, indices=None, weights=None):

          x = self.x
-         ww = self.scat_operator.backend.bk_reshape(
-             x[0 : self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]],
-             [self.KERNELSZ * self.KERNELSZ, self.n_chan_in, self.chanlist[0]],
+         ww = self.backend.bk_reshape(
+             x[0 : self.KERNELSZ * (self.KERNELSZ // 2 + 1) * self.n_chan_in * self.chanlist[0]],
+             [self.n_chan_in, self.KERNELSZ * (self.KERNELSZ // 2 + 1), self.chanlist[0]],
          )
-         nn = self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
+         nn = self.KERNELSZ * (self.KERNELSZ // 2 + 1) * self.n_chan_in * self.chanlist[0]

          im = self.scat_operator.healpix_layer(im, ww)
-         im = self.scat_operator.backend.bk_relu(im)
+         im = self.backend.bk_relu(im)
+
+         im = self.backend.bk_reduce_mean(self.backend.bk_reshape(im, [im.shape[0], im.shape[1], im.shape[2] // 4, 4]), 3)

          for k in range(self.nscale):
              ww = self.scat_operator.backend.bk_reshape(
                  x[
                      nn : nn
                      + self.KERNELSZ
-                     * self.KERNELSZ
+                     * (self.KERNELSZ // 2 + 1)
                      * self.chanlist[k]
                      * self.chanlist[k + 1]
                  ],
-                 [self.KERNELSZ * self.KERNELSZ, self.chanlist[k], self.chanlist[k + 1]],
+                 [self.chanlist[k], self.KERNELSZ * (self.KERNELSZ // 2 + 1), self.chanlist[k + 1]],
              )
              nn = (
                  nn
                  + self.KERNELSZ
-                 * self.KERNELSZ
+                 * (self.KERNELSZ // 2)
                  * self.chanlist[k]
                  * self.chanlist[k + 1]
              )
@@ -129,7 +130,7 @@ class CNN:
                  im, ww, indices=indices[k], weights=weights[k]
              )
              im = self.scat_operator.backend.bk_relu(im)
-             im = self.scat_operator.ud_grade_2(im, axis=0)
+             im = self.backend.bk_reduce_mean(self.backend.bk_reshape(im, [im.shape[0], im.shape[1], im.shape[2] // 4, 4]), 3)

          ww = self.scat_operator.backend.bk_reshape(
              x[
@@ -141,11 +142,11 @@ class CNN:

          im = self.scat_operator.backend.bk_matmul(
              self.scat_operator.backend.bk_reshape(
-                 im, [1, 12 * self.out_nside**2 * self.chanlist[self.nscale]]
+                 im, [im.shape[0], im.shape[1] * im.shape[2]]
              ),
              ww,
          )
-         im = self.scat_operator.backend.bk_reshape(im, [self.npar])
+         # im = self.scat_operator.backend.bk_reshape(im, [self.npar])
          im = self.scat_operator.backend.bk_relu(im)

          return im
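Note: both replacements of ud_grade_2 in eval reshape the pixel axis to [..., npix // 4, 4] and average over the last axis, which is the standard way to degrade a NESTED-ordered HEALPix map by one nside step: each parent cell is the mean of its four children. A toy NumPy sketch (illustrative data only):

    # Nested-ordering HEALPix degrade by one step: average each group of 4 children (toy sketch)
    import numpy as np

    nside = 4
    npix = 12 * nside**2
    m = np.arange(npix, dtype=float)                 # toy map in NESTED ordering

    m_lo = m.reshape(npix // 4, 4).mean(axis=1)      # map at nside // 2
    print(m_lo.shape, m_lo[:3])                      # (48,) [1.5 5.5 9.5]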