foscat 3.8.2__py3-none-any.whl → 2025.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/BkTorch.py CHANGED
@@ -1,15 +1,20 @@
1
1
  import sys
2
- import foscat.BkBase as BackendBase
2
+
3
3
  import numpy as np
4
4
  import torch
5
5
 
6
+ import foscat.BkBase as BackendBase
7
+
8
+
6
9
  class BkTorch(BackendBase.BackendBase):
7
-
10
+
8
11
  def __init__(self, *args, **kwargs):
9
12
  # Impose que use_2D=True pour la classe scat
10
- super().__init__(name='torch', *args, **kwargs)
13
+ super().__init__(name="torch", *args, **kwargs)
11
14
  self.backend = torch
12
- self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
15
+ self.device = (
16
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
17
+ )
13
18
 
14
19
  self.float64 = self.backend.float64
15
20
  self.float32 = self.backend.float32
@@ -22,11 +27,13 @@ class BkTorch(BackendBase.BackendBase):
22
27
  "float32": (self.backend.float32, self.backend.complex64),
23
28
  "float64": (self.backend.float64, self.backend.complex128),
24
29
  }
25
-
30
+
26
31
  if self.all_type in dtype_map:
27
32
  self.all_bk_type, self.all_cbk_type = dtype_map[self.all_type]
28
33
  else:
29
- raise ValueError(f"ERROR INIT foscat: {all_type} should be float32 or float64")
34
+ raise ValueError(
35
+ f"ERROR INIT foscat: {self.all_type} should be float32 or float64"
36
+ )
30
37
 
31
38
  # ===========================================================================
32
39
  # INIT
@@ -50,14 +57,88 @@ class BkTorch(BackendBase.BackendBase):
50
57
  except RuntimeError as e:
51
58
  # Memory growth must be set before GPUs have been initialized
52
59
  print(e)
53
-
54
- self.torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
60
+
61
+ self.torch_device = (
62
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
63
+ )
64
+
65
def binned_mean(self, data, cell_ids):
    """
    Average `data` along axis 1 over groups of cells sharing `cell_ids // 4`.

    Parameters
    ----------
    data : torch.Tensor
        Tensor of shape [B, N] or [B, N, A]; axis 1 is the cell axis.
        Additional trailing axes are also accepted.
    cell_ids : array-like of int, shape [N]
        One integer identifier per cell. Cells whose identifiers have the
        same integer quotient by 4 are averaged together.

    Returns
    -------
    mean_per_bin : torch.Tensor
        Per-group mean, shape [B, n_bins] or [B, n_bins, A].
    unique_groups : numpy.ndarray
        Sorted unique values of `cell_ids // 4`, shape [n_bins]; entry k
        labels bin k of `mean_per_bin`.
    """
    # np.unique works on host memory, so bring tensor ids back to numpy.
    if isinstance(cell_ids, torch.Tensor):
        cell_ids = cell_ids.detach().cpu().numpy()
    groups = np.asarray(cell_ids) // 4  # [N]

    unique_groups, inverse = np.unique(groups, return_inverse=True)
    n_bins = unique_groups.shape[0]

    # Bin index of every cell, on the same device as the data.
    bin_index = torch.as_tensor(
        inverse.reshape(-1), dtype=torch.long, device=data.device
    )
    counts = torch.bincount(bin_index, minlength=n_bins)  # [n_bins]

    # Accumulate along the cell axis only. Batch and trailing axes are
    # left untouched, so no flat index arithmetic (and no risk of mixing
    # batches or channels) is needed.
    sum_shape = list(data.shape)
    sum_shape[1] = n_bins
    sum_per_bin = torch.zeros(
        sum_shape, dtype=data.dtype, device=data.device
    ).index_add(1, bin_index, data)

    # Broadcast the counts as [1, n_bins, 1, ...].
    counts = counts.reshape([1, n_bins] + [1] * (data.dim() - 2))

    return sum_per_bin / counts, unique_groups
114
+
115
+ def average_by_cell_group(data, cell_ids):
116
+ """
117
+ data: tensor of shape [..., N, ...] (ex: [B, N, C])
118
+ cell_ids: tensor of shape [N]
119
+ Returns: mean_data of shape [..., G, ...] where G = number of unique cell_ids//4
120
+ """
121
+ original_shape = data.shape
122
+ leading = data.shape[:-2] # all dims before N
123
+ N = data.shape[-2]
124
+ trailing = data.shape[-1:] # all dims after N
125
+
126
+ groups = (cell_ids // 4).long() # [N]
127
+ unique_groups, group_indices, counts = torch.unique(
128
+ groups, return_inverse=True, return_counts=True
129
+ )
130
+
131
+ return torch.bincount(group_indices, weights=data) / counts, unique_groups
55
132
 
56
133
  # ---------------------------------------------−---------
57
134
  # -- BACKEND DEFINITION --
58
135
  # ---------------------------------------------−---------
59
136
def bk_SparseTensor(self, indice, w, dense_shape=[]):
    """
    Build a CSR sparse tensor and move it to `self.torch_device`.

    `indice` is transposed before being handed to `sparse_coo_tensor`,
    so it is given with one row per stored value. `w` holds the values
    and `dense_shape` the size of the dense equivalent.
    """
    coo = self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
    csr = coo.to_sparse_csr()
    return csr.to(self.torch_device)
61
142
 
62
143
def bk_stack(self, list, axis=0):
    """Stack the tensors in `list` along a new `axis` and move the result to `self.torch_device`."""
    stacked = self.backend.stack(list, axis=axis)
    return stacked.to(self.torch_device)
@@ -65,21 +146,20 @@ class BkTorch(BackendBase.BackendBase):
65
146
def bk_sparse_dense_matmul(self, smat, mat):
    """Return the matrix product `smat @ mat`, computed with `smat.matmul`."""
    product = smat.matmul(mat)
    return product
67
148
 
68
-
69
-
70
149
def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
    """
    2-D convolution with circular (wrap-around) padding.

    `x` is permuted from channels-last to channels-first and `w` from
    (k0, k1, in, out) to (out, in, k0, k1), e.g. (5, 5, 1, 4) -> (4, 1, 5, 5).
    The result is permuted back to channels-last.

    `strides` and `padding` are accepted but not used: the convolution
    always runs with stride 1 on an input padded by half the kernel size
    on each side.
    """
    from torch.nn import functional as F

    # Half kernel extents give a symmetric padding.
    half_k0 = w.shape[0] // 2
    half_k1 = w.shape[1] // 2

    channels_first = x.permute(0, 3, 1, 2)
    kernel = w.permute(3, 2, 0, 1)

    # F.pad lists the last axis first: (left, right, top, bottom).
    wrapped = F.pad(
        channels_first, (half_k1, half_k1, half_k0, half_k0), mode="circular"
    )

    convolved = F.conv2d(wrapped, kernel, stride=1, padding=0)
    return convolved.permute(0, 2, 3, 1)
83
163
 
84
164
  def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
85
165
  # to be written!!!
@@ -108,7 +188,7 @@ class BkTorch(BackendBase.BackendBase):
108
188
  def bk_flattenR(self, x):
109
189
  if self.bk_is_complex(x):
110
190
  rr = self.backend.reshape(
111
- self.bk_real(x), [np.prod(np.array(list(x.shape)))]
191
+ self.bk_real(x), [np.prod(np.array(list(x.shape)))]
112
192
  )
113
193
  ii = self.backend.reshape(
114
194
  self.bk_imag(x), [np.prod(np.array(list(x.shape)))]
@@ -122,18 +202,18 @@ class BkTorch(BackendBase.BackendBase):
122
202
 
123
203
def bk_resize_image(self, x, shape):
    """
    Resize a channels-last image batch to `shape` by bilinear interpolation.

    `x` is permuted to channels-first for `interpolate`, permuted back,
    then passed through `self.bk_cast`.
    """
    channels_first = x.permute(0, 3, 1, 2)
    resized = self.backend.nn.functional.interpolate(
        channels_first, size=shape, mode="bilinear", align_corners=False
    )
    channels_last = resized.permute(0, 2, 3, 1)
    return self.bk_cast(channels_last)
128
208
 
129
209
  def bk_L1(self, x):
130
210
  if x.dtype == self.all_cbk_type:
131
211
  xr = self.bk_real(x)
132
- xi = self.bk_imag(x)
212
+ # xi = self.bk_imag(x)
133
213
 
134
214
  r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
135
215
  # return r
136
- i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
216
+ # i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
137
217
 
138
218
  return r
139
219
  else:
@@ -163,7 +243,6 @@ class BkTorch(BackendBase.BackendBase):
163
243
def bk_size(self, data):
    """Return the total number of elements in `data` (its `numel()`)."""
    n_elements = data.numel()
    return n_elements
165
245
 
166
-
167
246
def constant(self, data):
    """Return `data` unchanged; no conversion or copy is made."""
    unchanged = data
    return unchanged
169
248
 
@@ -194,7 +273,7 @@ class BkTorch(BackendBase.BackendBase):
194
273
  r = self.backend.std(data)
195
274
  else:
196
275
  r = self.backend.std(data, axis)
197
-
276
+
198
277
  if self.bk_is_complex(data):
199
278
  return self.bk_complex(r, 0 * r)
200
279
  else:
@@ -245,7 +324,7 @@ class BkTorch(BackendBase.BackendBase):
245
324
 
246
325
def bk_tensor(self, data):
    """
    Convert `data` to a tensor on `self.torch_device`.

    torch has no `constant` function, so the conversion goes through
    `as_tensor`, which accepts Python sequences, numpy arrays and
    existing tensors.
    """
    return self.backend.as_tensor(data).to(self.torch_device)
248
-
327
+
249
328
def bk_shape_tensor(self, shape):
    """
    Return a tensor of the requested `shape` on `self.torch_device`.

    `torch.tensor` requires a `data` argument and has no `shape` keyword,
    so a tensor is allocated explicitly instead.
    NOTE(review): it is zero-filled so the contents are deterministic;
    confirm that callers do not expect a different initial value.
    """
    return self.backend.zeros(shape).to(self.torch_device)
251
330
 
@@ -304,12 +383,8 @@ class BkTorch(BackendBase.BackendBase):
304
383
  if axis is None:
305
384
  if data[0].dtype == self.all_cbk_type:
306
385
  ndata = len(data)
307
- xr = self.backend.concat(
308
- [self.bk_real(data[k]) for k in range(ndata)]
309
- )
310
- xi = self.backend.concat(
311
- [self.bk_imag(data[k]) for k in range(ndata)]
312
- )
386
+ xr = self.backend.concat([self.bk_real(data[k]) for k in range(ndata)])
387
+ xi = self.backend.concat([self.bk_imag(data[k]) for k in range(ndata)])
313
388
  return self.bk_complex(xr, xi)
314
389
  else:
315
390
  return self.backend.concat(data)
@@ -317,11 +392,11 @@ class BkTorch(BackendBase.BackendBase):
317
392
  if data[0].dtype == self.all_cbk_type:
318
393
  ndata = len(data)
319
394
  xr = self.backend.concat(
320
- [self.bk_real(data[k]) for k in range(ndata)], axis=axis
321
- )
395
+ [self.bk_real(data[k]) for k in range(ndata)], axis=axis
396
+ )
322
397
  xi = self.backend.concat(
323
- [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
324
- )
398
+ [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
399
+ )
325
400
  return self.bk_complex(xr, xi)
326
401
  else:
327
402
  return self.backend.concat(data, axis=axis)
@@ -329,28 +404,28 @@ class BkTorch(BackendBase.BackendBase):
329
404
def bk_zeros(self, shape, dtype=None):
    """Return a zero-filled tensor of `shape` and `dtype` on `self.torch_device`."""
    zeros = self.backend.zeros(shape, dtype=dtype)
    return zeros.to(self.torch_device)
331
406
 
332
def bk_gather(self, data, idx, axis=0):
    """
    Select the entries `idx` of `data` along `axis`.

    Equivalent to `data[:, ..., idx]` with `axis` leading full slices.
    Any valid axis is supported, including negative axes counted from
    the end.

    Raises
    ------
    IndexError
        If `axis` is outside the dimensions of `data`.
    """
    ndim = data.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(
            f"axis {axis} is out of bounds for data with {ndim} dimensions"
        )
    if axis < 0:
        axis += ndim

    # One full slice per leading axis, then the index on the target axis.
    selector = (slice(None),) * axis + (idx,)
    return data[selector]
342
417
 
343
418
def bk_reverse(self, data, axis=0):
    """Return `data` with the order of its entries reversed along `axis`."""
    flipped = self.backend.flip(data, dims=[axis])
    return flipped
345
420
 
346
421
def bk_fft(self, data):
    """Return the one-dimensional FFT of `data` as computed by `fft.fft`."""
    spectrum = self.backend.fft.fft(data)
    return spectrum
348
-
349
- def bk_fftn(self, data,dim=None):
350
- return self.backend.fft.fftn(data,dim=dim)
351
423
 
352
- def bk_ifftn(self, data,dim=None,norm=None):
353
- return self.backend.fft.ifftn(data,dim=dim,norm=norm)
424
def bk_fftn(self, data, dim=None):
    """Return the N-dimensional FFT of `data` over `dim`, which is passed straight to `fft.fftn`."""
    spectrum = self.backend.fft.fftn(data, dim=dim)
    return spectrum
426
+
427
def bk_ifftn(self, data, dim=None, norm=None):
    """Return the inverse N-dimensional FFT of `data`; `dim` and `norm` are passed straight to `fft.ifftn`."""
    signal = self.backend.fft.ifftn(data, dim=dim, norm=norm)
    return signal
354
429
 
355
430
def bk_rfft(self, data):
    """Return the one-dimensional FFT of real-valued `data` as computed by `fft.rfft`."""
    half_spectrum = self.backend.fft.rfft(data)
    return half_spectrum
@@ -374,12 +449,24 @@ class BkTorch(BackendBase.BackendBase):
374
449
def bk_relu(self, x):
    """Apply the backend's `relu` to `x` and return the result."""
    rectified = self.backend.relu(x)
    return rectified
376
451
 
377
def bk_clip_by_value(self, x, xmin, xmax):
    """
    Clamp `x` to the interval [`xmin`, `xmax`] and return a tensor.

    A numpy `x` is clipped with `np.clip` first. Any of the three
    arguments that is not already a tensor is converted to a float32
    tensor before the final `clamp`.
    """
    bk = self.backend

    def _as_tensor(value):
        # Tensors pass through untouched; everything else becomes float32.
        if isinstance(value, bk.Tensor):
            return value
        return bk.tensor(value, dtype=bk.float32)

    if isinstance(x, np.ndarray):
        x = np.clip(x, xmin, xmax)

    x = _as_tensor(x)
    xmin = _as_tensor(xmin)
    xmax = _as_tensor(xmax)
    return bk.clamp(x, min=xmin, max=xmax)
384
471
 
385
472
  def bk_cast(self, x):
@@ -424,31 +511,31 @@ class BkTorch(BackendBase.BackendBase):
424
511
  out_type = self.all_bk_type
425
512
 
426
513
  return x.type(out_type).to(self.torch_device)
427
-
428
def bk_variable(self, x):
    """Return `x` passed through `self.bk_cast`."""
    variable = self.bk_cast(x)
    return variable
430
-
431
- def bk_assign(self,x,y):
432
- x=y
433
-
434
- def bk_constant(self,x):
435
-
517
+
518
def bk_assign(self, x, y):
    """
    Return `y`, the value to be stored in place of `x`.

    `x` itself is not modified; the caller is expected to rebind its
    variable to the returned value.
    """
    new_value = y
    return new_value
520
+
521
def bk_constant(self, x):
    """Return `x` passed through `self.bk_cast`."""
    constant_value = self.bk_cast(x)
    return constant_value
437
-
438
def bk_cos(self, x):
    """Return the element-wise cosine of `x`."""
    cosine = self.backend.cos(x)
    return cosine
440
-
441
def bk_sin(self, x):
    """Return the element-wise sine of `x`."""
    sine = self.backend.sin(x)
    return sine
443
-
444
- def bk_arctan2(self,c,s):
445
- return self.backend.arctan2(c,s)
446
-
447
- def bk_empty(self,list):
530
+
531
def bk_arctan2(self, c, s):
    """
    Return the backend's `arctan2(c, s)`.

    The arguments are forwarded in the order given, so `c` is the first
    operand of `arctan2` and `s` the second.
    NOTE(review): if `c`/`s` stand for cosine/sine, check that this
    argument order is the intended one.
    """
    angle = self.backend.arctan2(c, s)
    return angle
533
+
534
def bk_empty(self, list):
    """Return an uninitialised tensor whose shape is given by `list`."""
    uninitialised = self.backend.empty(list)
    return uninitialised
449
-
450
def to_numpy(self, x):
    """
    Convert `x` to a numpy array.

    A numpy array is returned as is. A tensor is detached from the
    autograd graph, copied to host memory and converted. The detach is
    required because `.numpy()` raises on tensors that require grad.
    """
    if isinstance(x, np.ndarray):
        return x

    return x.detach().cpu().numpy()