foscat 2025.5.0__py3-none-any.whl → 2025.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/BkTensorflow.py CHANGED
@@ -68,10 +68,16 @@ class BkTensorflow(BackendBase.BackendBase):
68
68
  print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
69
69
  sys.stdout.flush()
70
70
  self.ngpu = len(logical_gpus)
71
- gpuname = logical_gpus[self.gpupos % self.ngpu].name
72
- self.gpulist = {}
73
- for i in range(self.ngpu):
74
- self.gpulist[i] = logical_gpus[i].name
71
+ if self.ngpu > 0:
72
+ gpuname = logical_gpus[self.gpupos % self.ngpu].name
73
+ self.gpulist = {}
74
+ for i in range(self.ngpu):
75
+ self.gpulist[i] = logical_gpus[i].name
76
+ else:
77
+ gpuname = "CPU:0"
78
+ self.gpulist = {}
79
+ self.gpulist[0] = gpuname
80
+ self.ngpu = 1
75
81
 
76
82
  except RuntimeError as e:
77
83
  # Memory growth must be set before GPUs have been initialized
@@ -122,18 +128,136 @@ class BkTensorflow(BackendBase.BackendBase):
122
128
 
123
129
  return x_padded
124
130
 
125
- def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
126
- kx = w.shape[0]
127
- ky = w.shape[1]
128
- x_padded = self.periodic_pad(x, kx // 2, ky // 2)
129
- return self.backend.nn.conv2d(x_padded, w, strides=strides, padding="VALID")
131
+ def binned_mean(self, data, cell_ids):
132
+ """
133
+ data: Tensor of shape [..., N] (float32 or float64)
134
+ cell_ids: Tensor of shape [N], int indices in [0, n_bins)
135
+ Returns: mean per bin, shape [..., n_bins]
136
+ """
137
+ ishape = list(data.shape)
138
+ A = 1
139
+ for k in range(len(ishape) - 1):
140
+ A *= ishape[k]
141
+ N = tf.shape(data)[-1]
142
+
143
+ # Step 1: group indices
144
+ groups = tf.math.floordiv(cell_ids, 4) # [N]
145
+ unique_groups, I = tf.unique(groups) # I: [N]
146
+ n_bins = tf.shape(unique_groups)[0]
147
+
148
+ # Step 2: build I_tiled with batch + channel offsets
149
+ I_tiled = tf.tile(I[None, :], [A, 1]) # shape [A, N]
150
+
151
+ # Offset index to flatten across [A, n_bins]
152
+ batch_channel_offsets = tf.range(A)[:, None] * n_bins
153
+ I_offset = I_tiled + batch_channel_offsets # shape [A, N]
154
+
155
+ # Step 3: flatten data to shape [A, N]
156
+ data_reshaped = tf.reshape(data, [A, N]) # shape [A, N]
157
+
158
+ # Flatten all for scatter_nd
159
+ indices = tf.reshape(I_offset, [-1]) # [A*N]
160
+ values = tf.reshape(data_reshaped, [-1]) # [A*N]
161
+
162
+ """
163
+ # Prepare for scatter: indices → [A*N, 1]
164
+ scatter_indices = tf.expand_dims(indices, axis=1)
165
+ scatter_indices = tf.cast(scatter_indices, tf.int64)
166
+ """
167
+ total_bins = A * n_bins
168
+
169
+ # Step 4: sum per bin
170
+ sum_per_bin = tf.math.unsorted_segment_sum(values, indices, total_bins)
171
+ sum_per_bin = tf.reshape(sum_per_bin, ishape[0:-1] + [n_bins]) # [..., n_bins]
172
+
173
+ # Step 5: count per bin (same indices)
174
+ counts = tf.math.unsorted_segment_sum(1.0 + 0 * values, indices, total_bins)
175
+ # counts = tf.math.bincount(indices, minlength=total_bins, maxlength=total_bins)
176
+ counts = tf.reshape(counts, ishape[0:-1] + [n_bins])
177
+ # counts = tf.maximum(counts, 1) # Avoid division by zero
178
+ # counts = tf.cast(counts, dtype=data.dtype)
179
+
180
+ # Step 6: mean
181
+ mean_per_bin = sum_per_bin / counts # [..., n_bins]
182
+
183
+ return mean_per_bin, unique_groups
184
+
185
+ def conv2d(self, x, w):
186
+ """
187
+ Perform 2D convolution using TensorFlow.
188
+
189
+ Args:
190
+ x: Tensor of shape [..., Nx, Ny] – input
191
+ w: Tensor of shape [O_c, wx, wy] – conv weights
192
+
193
+ Returns:
194
+ Tensor of shape [..., O_c, Nx, Ny]
195
+ """
196
+ # Extract shape
197
+ *leading_dims, Nx, Ny = x.shape
198
+ O_c, wx, wy = w.shape
199
+
200
+ # Flatten leading dims into a batch dimension
201
+ B = tf.reduce_prod(leading_dims) if leading_dims else 1
202
+ x = tf.reshape(x, [B, Nx, Ny, 1]) # TensorFlow format: [B, H, W, C_in=1]
203
+
204
+ # Reshape weights to [wx, wy, in_channels=1, out_channels]
205
+ w = tf.reshape(w, [O_c, wx, wy])
206
+ w = tf.transpose(w, perm=[1, 2, 0]) # [wx, wy, O_c]
207
+ w = tf.reshape(w, [wx, wy, 1, O_c]) # [wx, wy, C_in=1, C_out]
208
+
209
+ # Apply 'reflect' padding manually
210
+ pad_x = wx // 2
211
+ pad_y = wy // 2
212
+ x_padded = tf.pad(
213
+ x, [[0, 0], [pad_x, pad_x], [pad_y, pad_y], [0, 0]], mode="REFLECT"
214
+ )
215
+
216
+ # Perform convolution
217
+ y = tf.nn.conv2d(
218
+ x_padded, w, strides=[1, 1, 1, 1], padding="VALID"
219
+ ) # [B, Nx, Ny, O_c]
220
+
221
+ # Transpose back to match original format: [..., O_c, Nx, Ny]
222
+ y = tf.transpose(y, [0, 3, 1, 2]) # [B, O_c, Nx, Ny]
223
+ y = tf.reshape(y, [*leading_dims, O_c, Nx, Ny])
224
+
225
+ return y
226
+
227
+ def conv1d(self, x, w):
228
+ """
229
+ Perform 1D convolution using TensorFlow.
230
+
231
+ Args:
232
+ x: Tensor of shape [..., N] – input
233
+ w: Tensor of shape [k] – conv weights
234
+
235
+ Returns:
236
+ Tensor of shape [..., N]
237
+ """
238
+ # Extract shapes
239
+ *leading_dims, N = x.shape
240
+ k = w.shape[0]
241
+
242
+ # Flatten leading dims into batch dimension
243
+ B = tf.reduce_prod(leading_dims) if leading_dims else 1
244
+ x = tf.reshape(x, [B, N, 1]) # TensorFlow 1D format: [B, L, C=1]
245
+
246
+ # Prepare weights: [k, in_channels=1, out_channels=O_c]
247
+ w = tf.reshape(w, [k, 1, 1])
248
+
249
+ # Apply 'reflect' padding
250
+ pad = k // 2
251
+ x_padded = tf.pad(x, [[0, 0], [pad, pad], [0, 0]], mode="REFLECT")
252
+
253
+ # Perform convolution
254
+ y = tf.nn.conv1d(x_padded, w, stride=1, padding="VALID") # [B, N, O_c]
130
255
 
131
- def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
132
- kx = w.shape[0]
133
- paddings = self.backend.constant([[0, 0], [kx // 2, kx // 2], [0, 0]])
134
- tmp = self.backend.pad(x, paddings, "SYMMETRIC")
256
+ # Transpose to [B, O_c, N] and reshape back
257
+ y = tf.transpose(y, [0, 2, 1]) # [B, 1, N]
258
+ y = tf.reshape(y, [*leading_dims, N]) # [..., N]
135
259
 
136
- return self.backend.nn.conv1d(tmp, w, stride=strides, padding="VALID")
260
+ return y
137
261
 
138
262
  def bk_threshold(self, x, threshold, greater=True):
139
263
 
foscat/BkTorch.py CHANGED
@@ -2,6 +2,7 @@ import sys
2
2
 
3
3
  import numpy as np
4
4
  import torch
5
+ import torch.nn.functional as F
5
6
 
6
7
  import foscat.BkBase as BackendBase
7
8
 
@@ -62,55 +63,69 @@ class BkTorch(BackendBase.BackendBase):
62
63
  torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
63
64
  )
64
65
 
66
+ import torch
67
+
65
68
  def binned_mean(self, data, cell_ids):
66
69
  """
67
- data: Tensor of shape [B, N, A]
68
- I: Tensor of shape [N], integer indices in [0, n_bins)
69
- Returns: mean per bin, shape [B, n_bins, A]
70
+ Compute the mean over groups of 4 nested HEALPix cells (nside → nside/2).
71
+
72
+ Args:
73
+ data (torch.Tensor): Tensor of shape [..., N], where N is the number of HEALPix cells.
74
+ cell_ids (torch.LongTensor): Tensor of shape [N], with cell indices (nested ordering).
75
+
76
+ Returns:
77
+ torch.Tensor: Tensor of shape [..., n_bins], with averaged values per group of 4 cells.
70
78
  """
71
- groups = cell_ids // 4 # [N]
79
+ if isinstance(data, np.ndarray):
80
+ data = torch.from_numpy(data).to(
81
+ dtype=torch.float32, device=self.torch_device
82
+ )
83
+ if isinstance(cell_ids, np.ndarray):
84
+ cell_ids = torch.from_numpy(cell_ids).to(
85
+ dtype=torch.long, device=self.torch_device
86
+ )
72
87
 
73
- unique_groups, I = np.unique(groups, return_inverse=True)
88
+ # Compute supercell ids by grouping 4 nested cells together
89
+ groups = cell_ids // 4
74
90
 
91
+ # Get unique group ids and inverse mapping
92
+ unique_groups, inverse_indices = torch.unique(groups, return_inverse=True)
75
93
  n_bins = unique_groups.shape[0]
76
94
 
77
- B = data.shape[0]
95
+ # Flatten all leading dimensions into a single batch dimension
96
+ original_shape = data.shape[:-1]
97
+ N = data.shape[-1]
98
+ data_flat = data.reshape(-1, N) # Shape: [B, N]
78
99
 
79
- counts = torch.bincount(torch.tensor(I).to(data.device))[None, :]
100
+ # Prepare to compute sums using scatter_add
101
+ B = data_flat.shape[0]
80
102
 
81
- I = np.tile(I, B) + np.tile(n_bins * np.arange(B, dtype="int"), data.shape[1])
103
+ # Repeat inverse indices for each batch element
104
+ idx = inverse_indices.repeat(B, 1) # Shape: [B, N]
82
105
 
83
- if len(data.shape) == 3:
84
- A = data.shape[2]
85
- I = np.repeat(I, A) * A + np.repeat(
86
- np.arange(A, dtype="int"), data.shape[1] * B
87
- )
106
+ # Offset indices to simulate a per-batch scatter into [B * n_bins]
107
+ batch_offsets = torch.arange(B, device=data.device).unsqueeze(1) * n_bins
108
+ idx_offset = idx + batch_offsets # Shape: [B, N]
88
109
 
89
- I = torch.tensor(I).to(data.device)
110
+ # Flatten everything for scatter
111
+ idx_offset_flat = idx_offset.flatten()
112
+ data_flat_flat = data_flat.flatten()
90
113
 
91
- # Comptage par bin
92
- if len(data.shape) == 2:
93
- sum_per_bin = torch.zeros(
94
- [B * n_bins], dtype=data.dtype, device=data.device
95
- )
96
- sum_per_bin = sum_per_bin.scatter_add(
97
- 0, I, self.bk_reshape(data, B * data.shape[1])
98
- ).reshape(B, n_bins)
114
+ # Accumulate sums per bin
115
+ out = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
116
+ out = out.scatter_add(0, idx_offset_flat, data_flat_flat)
99
117
 
100
- mean_per_bin = sum_per_bin / counts # [B, n_bins, A]
101
- else:
102
- sum_per_bin = torch.zeros(
103
- [B * n_bins * A], dtype=data.dtype, device=data.device
104
- )
105
- sum_per_bin = sum_per_bin.scatter_add(
106
- 0, I, self.bk_reshape(data, B * data.shape[1] * A)
107
- ).reshape(
108
- B, n_bins, A
109
- ) # [B, n_bins]
118
+ # Count number of elements per bin (to compute mean)
119
+ ones = torch.ones_like(data_flat_flat)
120
+ counts = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
121
+ counts = counts.scatter_add(0, idx_offset_flat, ones)
110
122
 
111
- mean_per_bin = sum_per_bin / counts[:, :, None] # [B, n_bins, A]
123
+ # Compute mean
124
+ mean = out / counts # Shape: [B * n_bins]
125
+ mean = mean.view(B, n_bins)
112
126
 
113
- return mean_per_bin, unique_groups
127
+ # Restore original leading dimensions
128
+ return mean.view(*original_shape, n_bins), unique_groups
114
129
 
115
130
  def average_by_cell_group(data, cell_ids):
116
131
  """
@@ -146,20 +161,40 @@ class BkTorch(BackendBase.BackendBase):
146
161
  def bk_sparse_dense_matmul(self, smat, mat):
147
162
  return smat.matmul(mat)
148
163
 
149
- def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
150
- import torch.nn.functional as F
164
+ def conv2d(self, x, w):
165
+ """
166
+ Perform 2D convolution using PyTorch format.
167
+
168
+ Args:
169
+ x: Tensor of shape [..., Nx, Ny] – input
170
+ w: Tensor of shape [O_c, wx, wy] – conv weights
171
+
172
+ Returns:
173
+ Tensor of shape [..., O_c, Nx, Ny]
174
+ """
175
+ *leading_dims, Nx, Ny = x.shape # extract leading dims
176
+ O_c, wx, wy = w.shape
151
177
 
152
- lx = x.permute(0, 3, 1, 2)
153
- wx = w.permute(3, 2, 0, 1) # de (5, 5, 1, 4) à (4, 1, 5, 5)
178
+ # Flatten leading dims into batch dimension
179
+ B = int(torch.prod(torch.tensor(leading_dims))) if leading_dims else 1
180
+ x = x.reshape(B, 1, Nx, Ny) # [B, 1, Nx, Ny]
154
181
 
155
- # Calculer le padding symétrique
156
- kx, ky = w.shape[0], w.shape[1]
182
+ # Reshape filters to match conv2d format [O_c, 1, wx, wy]
183
+ w = w[:, None, :, :] # [O_c, 1, wx, wy]
157
184
 
158
- # Appliquer le padding
159
- x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode="circular")
185
+ pad_x = wx // 2
186
+ pad_y = wy // 2
160
187
 
161
- # Appliquer la convolution
162
- return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0, 2, 3, 1)
188
+ # Reflective padding to reduce edge artifacts
189
+ x_padded = F.pad(x, (pad_y, pad_y, pad_x, pad_x), mode="reflect")
190
+
191
+ # Apply convolution
192
+ y = F.conv2d(x_padded, w) # [B, O_c, Nx, Ny]
193
+
194
+ # Restore original leading dimensions
195
+ y = y.reshape(*leading_dims, O_c, Nx, Ny)
196
+
197
+ return y
163
198
 
164
199
  def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
165
200
  # to be written!!!
@@ -376,7 +411,7 @@ class BkTorch(BackendBase.BackendBase):
376
411
  return self.backend.unsqueeze(data, axis)
377
412
 
378
413
  def bk_transpose(self, data, thelist):
379
- return self.backend.transpose(data, thelist)
414
+ return self.backend.transpose(data, thelist[0], thelist[1])
380
415
 
381
416
  def bk_concat(self, data, axis=None):
382
417