foscat 3.8.2__py3-none-any.whl → 2025.3.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- foscat/BkBase.py +36 -35
- foscat/BkNumpy.py +53 -62
- foscat/BkTensorflow.py +87 -88
- foscat/BkTorch.py +159 -72
- foscat/FoCUS.py +228 -89
- foscat/Synthesis.py +3 -3
- foscat/alm.py +188 -170
- foscat/backend.py +84 -70
- foscat/scat_cov.py +2138 -2220
- foscat/scat_cov2D.py +146 -53
- {foscat-3.8.2.dist-info → foscat-2025.3.0.dist-info}/METADATA +3 -2
- foscat-2025.3.0.dist-info/RECORD +30 -0
- {foscat-3.8.2.dist-info → foscat-2025.3.0.dist-info}/WHEEL +1 -1
- foscat-3.8.2.dist-info/RECORD +0 -30
- {foscat-3.8.2.dist-info → foscat-2025.3.0.dist-info/licenses}/LICENSE +0 -0
- {foscat-3.8.2.dist-info → foscat-2025.3.0.dist-info}/top_level.txt +0 -0
foscat/BkTorch.py
CHANGED
```diff
@@ -1,15 +1,20 @@
 import sys
-
+
 import numpy as np
 import torch
 
+import foscat.BkBase as BackendBase
+
+
 class BkTorch(BackendBase.BackendBase):
-
+
     def __init__(self, *args, **kwargs):
         # Force use_2D=True for the scat class
-        super().__init__(name=
+        super().__init__(name="torch", *args, **kwargs)
         self.backend = torch
-        self.device =
+        self.device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
 
         self.float64 = self.backend.float64
         self.float32 = self.backend.float32
```
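The constructor now spells out the device fallback that the rest of the class relies on. A minimal standalone sketch of the same pattern (variable names are ours, not foscat API):

```python
import torch

# Pick CUDA when available, otherwise fall back to CPU -- the same
# expression the constructor assigns to self.device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

x = torch.zeros(4, device=device)  # tensors land on the selected device
print(x.device)
```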
```diff
@@ -22,11 +27,13 @@ class BkTorch(BackendBase.BackendBase):
             "float32": (self.backend.float32, self.backend.complex64),
             "float64": (self.backend.float64, self.backend.complex128),
         }
-
+
         if self.all_type in dtype_map:
             self.all_bk_type, self.all_cbk_type = dtype_map[self.all_type]
         else:
-            raise ValueError(
+            raise ValueError(
+                f"ERROR INIT foscat: {self.all_type} should be float32 or float64"
+            )
 
     # ===========================================================================
     # INIT
```
```diff
@@ -50,14 +57,88 @@ class BkTorch(BackendBase.BackendBase):
         except RuntimeError as e:
             # Memory growth must be set before GPUs have been initialized
             print(e)
-
-        self.torch_device =
+
+        self.torch_device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
+
+    def binned_mean(self, data, cell_ids):
+        """
+        data: Tensor of shape [B, N, A]
+        I: Tensor of shape [N], integer indices in [0, n_bins)
+        Returns: mean per bin, shape [B, n_bins, A]
+        """
+        groups = cell_ids // 4  # [N]
+
+        unique_groups, I = np.unique(groups, return_inverse=True)
+
+        n_bins = unique_groups.shape[0]
+
+        B = data.shape[0]
+
+        counts = torch.bincount(torch.tensor(I).to(data.device))[None, :]
+
+        I = np.tile(I, B) + np.tile(n_bins * np.arange(B, dtype="int"), data.shape[1])
+
+        if len(data.shape) == 3:
+            A = data.shape[2]
+            I = np.repeat(I, A) * A + np.repeat(
+                np.arange(A, dtype="int"), data.shape[1] * B
+            )
+
+        I = torch.tensor(I).to(data.device)
+
+        # Count per bin
+        if len(data.shape) == 2:
+            sum_per_bin = torch.zeros(
+                [B * n_bins], dtype=data.dtype, device=data.device
+            )
+            sum_per_bin = sum_per_bin.scatter_add(
+                0, I, self.bk_reshape(data, B * data.shape[1])
+            ).reshape(B, n_bins)
+
+            mean_per_bin = sum_per_bin / counts  # [B, n_bins, A]
+        else:
+            sum_per_bin = torch.zeros(
+                [B * n_bins * A], dtype=data.dtype, device=data.device
+            )
+            sum_per_bin = sum_per_bin.scatter_add(
+                0, I, self.bk_reshape(data, B * data.shape[1] * A)
+            ).reshape(
+                B, n_bins, A
+            )  # [B, n_bins]
+
+            mean_per_bin = sum_per_bin / counts[:, :, None]  # [B, n_bins, A]
+
+        return mean_per_bin, unique_groups
+
+    def average_by_cell_group(data, cell_ids):
+        """
+        data: tensor of shape [..., N, ...] (ex: [B, N, C])
+        cell_ids: tensor of shape [N]
+        Returns: mean_data of shape [..., G, ...] where G = number of unique cell_ids//4
+        """
+        original_shape = data.shape
+        leading = data.shape[:-2]  # all dims before N
+        N = data.shape[-2]
+        trailing = data.shape[-1:]  # all dims after N
+
+        groups = (cell_ids // 4).long()  # [N]
+        unique_groups, group_indices, counts = torch.unique(
+            groups, return_inverse=True, return_counts=True
+        )
+
+        return torch.bincount(group_indices, weights=data) / counts, unique_groups
 
     # ---------------------------------------------−---------
     # -- BACKEND DEFINITION --
     # ---------------------------------------------−---------
     def bk_SparseTensor(self, indice, w, dense_shape=[]):
-        return
+        return (
+            self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
+            .to_sparse_csr()
+            .to(self.torch_device)
+        )
 
     def bk_stack(self, list, axis=0):
         return self.backend.stack(list, axis=axis).to(self.torch_device)
```
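The new `binned_mean` builds one flat bin index and accumulates with the out-of-place `scatter_add`, then divides by the per-bin counts. A self-contained sketch of that technique for the 2-D `[B, N]` case; shapes and names here are ours, not foscat API:

```python
import torch

# Bin means via scatter_add: sum each column of `data` into its bin,
# offsetting the index by n_bins per batch row so a single 1-D
# scatter_add covers the whole [B, N] tensor.
B, N, n_bins = 2, 12, 3
data = torch.arange(B * N, dtype=torch.float32).reshape(B, N)
bins = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])  # bin of each column

counts = torch.bincount(bins, minlength=n_bins)                     # [n_bins]
flat_idx = (bins[None, :] + n_bins * torch.arange(B)[:, None]).reshape(-1)

sums = torch.zeros(B * n_bins).scatter_add(0, flat_idx, data.reshape(-1))
mean_per_bin = sums.reshape(B, n_bins) / counts                     # [B, n_bins]

# Cross-check against an explicit per-bin loop.
ref = torch.stack([data[:, bins == b].mean(dim=1) for b in range(n_bins)], dim=1)
assert torch.allclose(mean_per_bin, ref)
```

The per-row offset is the trick: it turns a batched group-by into one flat scatter, avoiding any Python loop over the batch dimension.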
```diff
@@ -65,21 +146,20 @@ class BkTorch(BackendBase.BackendBase):
     def bk_sparse_dense_matmul(self, smat, mat):
         return smat.matmul(mat)
 
-
-
     def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
         import torch.nn.functional as F
+
         lx = x.permute(0, 3, 1, 2)
         wx = w.permute(3, 2, 0, 1)  # from (5, 5, 1, 4) to (4, 1, 5, 5)
 
         # Compute the symmetric padding
         kx, ky = w.shape[0], w.shape[1]
-
+
         # Apply the padding
-        x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode=
+        x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode="circular")
 
         # Apply the convolution
-        return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0,2,3,1)
+        return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0, 2, 3, 1)
 
     def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
         # to be written!!!
```
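`conv2d` now pins the padding mode to `"circular"`, i.e. a periodic "SAME"-size convolution: pad the input by half the kernel on each side, then convolve with no further padding. A toy sketch of that pattern (all shapes are ours):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 8, 1)   # NHWC, the layout conv2d receives
w = torch.randn(3, 3, 1, 4)   # (kx, ky, n_in, n_out), TF-style kernel layout

lx = x.permute(0, 3, 1, 2)    # NHWC -> NCHW for torch
wx = w.permute(3, 2, 0, 1)    # -> (n_out, n_in, kx, ky)
kx, ky = w.shape[0], w.shape[1]

# Periodic padding by half the kernel keeps the output the same size.
x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode="circular")
y = F.conv2d(x_padded, wx, stride=1, padding=0).permute(0, 2, 3, 1)
print(y.shape)  # torch.Size([1, 8, 8, 4]) -- same H and W as the input
```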
```diff
@@ -108,7 +188,7 @@
     def bk_flattenR(self, x):
         if self.bk_is_complex(x):
             rr = self.backend.reshape(
-
+                self.bk_real(x), [np.prod(np.array(list(x.shape)))]
             )
             ii = self.backend.reshape(
                 self.bk_imag(x), [np.prod(np.array(list(x.shape)))]
```
```diff
@@ -122,18 +202,18 @@
 
     def bk_resize_image(self, x, shape):
         tmp = self.backend.nn.functional.interpolate(
-
-
-        return self.bk_cast(tmp.permute(0,2,3,1))
+            x.permute(0, 3, 1, 2), size=shape, mode="bilinear", align_corners=False
+        )
+        return self.bk_cast(tmp.permute(0, 2, 3, 1))
 
     def bk_L1(self, x):
         if x.dtype == self.all_cbk_type:
             xr = self.bk_real(x)
-            xi = self.bk_imag(x)
+            # xi = self.bk_imag(x)
 
             r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
             # return r
-            i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
+            # i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
 
             return r
         else:
```
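`bk_resize_image` now writes out the full `interpolate` call: torch interpolates in NCHW, so the NHWC input is permuted in and back out. A quick sketch of the same round trip (shapes are ours):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 16, 3)  # NHWC image batch
tmp = F.interpolate(
    x.permute(0, 3, 1, 2), size=(8, 8), mode="bilinear", align_corners=False
)
y = tmp.permute(0, 2, 3, 1)    # back to NHWC
print(y.shape)                 # torch.Size([1, 8, 8, 3])
```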
```diff
@@ -163,7 +243,6 @@
     def bk_size(self, data):
         return data.numel()
 
-
     def constant(self, data):
         return data
 
```
```diff
@@ -194,7 +273,7 @@
             r = self.backend.std(data)
         else:
             r = self.backend.std(data, axis)
-
+
         if self.bk_is_complex(data):
             return self.bk_complex(r, 0 * r)
         else:
```
```diff
@@ -245,7 +324,7 @@
 
     def bk_tensor(self, data):
         return self.backend.constant(data).to(self.torch_device)
-
+
     def bk_shape_tensor(self, shape):
         return self.backend.tensor(shape=shape).to(self.torch_device)
 
```
```diff
@@ -304,12 +383,8 @@
         if axis is None:
             if data[0].dtype == self.all_cbk_type:
                 ndata = len(data)
-                xr = self.backend.concat(
-
-                )
-                xi = self.backend.concat(
-                    [self.bk_imag(data[k]) for k in range(ndata)]
-                )
+                xr = self.backend.concat([self.bk_real(data[k]) for k in range(ndata)])
+                xi = self.backend.concat([self.bk_imag(data[k]) for k in range(ndata)])
                 return self.bk_complex(xr, xi)
             else:
                 return self.backend.concat(data)
```
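For complex inputs, `bk_concat` keeps concatenating the real and imaginary parts separately and recombining them; the edit only collapses each call onto one line. The pattern itself, as a toy equivalence check (names are ours; `torch.cat` is the op `torch.concat` aliases):

```python
import torch

data = [torch.randn(4, dtype=torch.complex64) for _ in range(3)]

xr = torch.cat([d.real for d in data])
xi = torch.cat([d.imag for d in data])
out = torch.complex(xr, xi)           # rebuild the complex tensor

assert torch.equal(out, torch.cat(data))  # same as concatenating directly
```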
```diff
@@ -317,11 +392,11 @@
         if data[0].dtype == self.all_cbk_type:
             ndata = len(data)
             xr = self.backend.concat(
-
-
+                [self.bk_real(data[k]) for k in range(ndata)], axis=axis
+            )
             xi = self.backend.concat(
-
-
+                [self.bk_imag(data[k]) for k in range(ndata)], axis=axis
+            )
             return self.bk_complex(xr, xi)
         else:
             return self.backend.concat(data, axis=axis)
```
```diff
@@ -329,28 +404,28 @@
     def bk_zeros(self, shape, dtype=None):
         return self.backend.zeros(shape, dtype=dtype).to(self.torch_device)
 
-    def bk_gather(self, data, idx,axis=0):
-        if axis==0:
+    def bk_gather(self, data, idx, axis=0):
+        if axis == 0:
             return data[idx]
-        elif axis==1:
-            return data[:,idx]
-        elif axis==2:
-            return data[
-        elif axis==3:
-            return data[
-        return data[
+        elif axis == 1:
+            return data[:, idx]
+        elif axis == 2:
+            return data[:, :, idx]
+        elif axis == 3:
+            return data[:, :, :, idx]
+        return data[:, :, :, :, idx]
 
     def bk_reverse(self, data, axis=0):
         return self.backend.flip(data, dims=[axis])
 
     def bk_fft(self, data):
         return self.backend.fft.fft(data)
-
-    def bk_fftn(self, data,dim=None):
-        return self.backend.fft.fftn(data,dim=dim)
 
-    def 
-        return self.backend.fft.
+    def bk_fftn(self, data, dim=None):
+        return self.backend.fft.fftn(data, dim=dim)
+
+    def bk_ifftn(self, data, dim=None, norm=None):
+        return self.backend.fft.ifftn(data, dim=dim, norm=norm)
 
     def bk_rfft(self, data):
         return self.backend.fft.rfft(data)
```
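`bk_ifftn` is the newly completed inverse of `bk_fftn`, forwarding `dim` and `norm` to `torch.fft.ifftn`. With matching arguments the two form a round trip; a quick check of the underlying calls:

```python
import torch

x = torch.randn(4, 4, dtype=torch.complex64)

y = torch.fft.fftn(x, dim=(-2, -1))
x_back = torch.fft.ifftn(y, dim=(-2, -1), norm=None)  # norm=None -> default "backward"

assert torch.allclose(x_back, x, atol=1e-5)
```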
```diff
@@ -374,12 +449,24 @@
     def bk_relu(self, x):
         return self.backend.relu(x)
 
-    def bk_clip_by_value(self, x,xmin,xmax):
+    def bk_clip_by_value(self, x, xmin, xmax):
         if isinstance(x, np.ndarray):
-            x = np.clip(x,xmin,xmax)
-        x =
-
-
+            x = np.clip(x, xmin, xmax)
+        x = (
+            self.backend.tensor(x, dtype=self.backend.float32)
+            if not isinstance(x, self.backend.Tensor)
+            else x
+        )
+        xmin = (
+            self.backend.tensor(xmin, dtype=self.backend.float32)
+            if not isinstance(xmin, self.backend.Tensor)
+            else xmin
+        )
+        xmax = (
+            self.backend.tensor(xmax, dtype=self.backend.float32)
+            if not isinstance(xmax, self.backend.Tensor)
+            else xmax
+        )
         return self.backend.clamp(x, min=xmin, max=xmax)
 
     def bk_cast(self, x):
```
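`bk_clip_by_value` now coerces `x`, `xmin`, and `xmax` to tensors before calling `torch.clamp`, which accepts tensor bounds. A minimal sketch of the resulting call (values are ours):

```python
import numpy as np
import torch

x = torch.as_tensor(np.array([-2.0, 0.5, 3.0]), dtype=torch.float32)
xmin = torch.tensor(0.0)   # tensor bounds, as the coerced method produces
xmax = torch.tensor(1.0)

print(torch.clamp(x, min=xmin, max=xmax))  # tensor([0.0000, 0.5000, 1.0000])
```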
```diff
@@ -424,31 +511,31 @@
             out_type = self.all_bk_type
 
         return x.type(out_type).to(self.torch_device)
-
-    def bk_variable(self,x):
+
+    def bk_variable(self, x):
         return self.bk_cast(x)
-
-    def bk_assign(self,x,y):
-
-
-    def bk_constant(self,x):
-
+
+    def bk_assign(self, x, y):
+        return y
+
+    def bk_constant(self, x):
+
         return self.bk_cast(x)
-
-    def bk_cos(self,x):
+
+    def bk_cos(self, x):
         return self.backend.cos(x)
-
-    def bk_sin(self,x):
+
+    def bk_sin(self, x):
         return self.backend.sin(x)
-
-    def bk_arctan2(self,c,s):
-        return self.backend.arctan2(c,s)
-
-    def bk_empty(self,list):
+
+    def bk_arctan2(self, c, s):
+        return self.backend.arctan2(c, s)
+
+    def bk_empty(self, list):
         return self.backend.empty(list)
-
-    def to_numpy(self,x):
+
+    def to_numpy(self, x):
         if isinstance(x, np.ndarray):
             return x
-
+
         return x.cpu().numpy()
```
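`to_numpy` keeps the same host-transfer pattern: `.numpy()` only works on CPU tensors, so anything on a GPU must come through `.cpu()` first. A short sketch (device pick is ours):

```python
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
x = torch.ones(3, device=device)

arr = x.cpu().numpy()  # safe on either device; x.numpy() alone fails on CUDA
print(arr)
```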