AOT-biomaps 2.9.138__py3-none-any.whl → 2.9.279__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of AOT-biomaps might be problematic; see the original registry diff page for more details.

Files changed (31):
  1. AOT_biomaps/AOT_Acoustic/AcousticTools.py +35 -115
  2. AOT_biomaps/AOT_Acoustic/StructuredWave.py +2 -2
  3. AOT_biomaps/AOT_Acoustic/_mainAcoustic.py +22 -18
  4. AOT_biomaps/AOT_Experiment/Tomography.py +74 -4
  5. AOT_biomaps/AOT_Experiment/_mainExperiment.py +102 -68
  6. AOT_biomaps/AOT_Optic/_mainOptic.py +124 -58
  7. AOT_biomaps/AOT_Recon/AOT_Optimizers/DEPIERRO.py +72 -108
  8. AOT_biomaps/AOT_Recon/AOT_Optimizers/LS.py +474 -289
  9. AOT_biomaps/AOT_Recon/AOT_Optimizers/MAPEM.py +173 -68
  10. AOT_biomaps/AOT_Recon/AOT_Optimizers/MLEM.py +360 -154
  11. AOT_biomaps/AOT_Recon/AOT_Optimizers/PDHG.py +150 -111
  12. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/RelativeDifferences.py +10 -14
  13. AOT_biomaps/AOT_Recon/AOT_SparseSMatrix/SparseSMatrix_CSR.py +281 -0
  14. AOT_biomaps/AOT_Recon/AOT_SparseSMatrix/SparseSMatrix_SELL.py +328 -0
  15. AOT_biomaps/AOT_Recon/AOT_SparseSMatrix/__init__.py +2 -0
  16. AOT_biomaps/AOT_Recon/AOT_biomaps_kernels.cubin +0 -0
  17. AOT_biomaps/AOT_Recon/AlgebraicRecon.py +359 -238
  18. AOT_biomaps/AOT_Recon/AnalyticRecon.py +29 -41
  19. AOT_biomaps/AOT_Recon/BayesianRecon.py +165 -91
  20. AOT_biomaps/AOT_Recon/DeepLearningRecon.py +4 -1
  21. AOT_biomaps/AOT_Recon/PrimalDualRecon.py +175 -31
  22. AOT_biomaps/AOT_Recon/ReconEnums.py +38 -3
  23. AOT_biomaps/AOT_Recon/ReconTools.py +184 -77
  24. AOT_biomaps/AOT_Recon/__init__.py +1 -0
  25. AOT_biomaps/AOT_Recon/_mainRecon.py +144 -74
  26. AOT_biomaps/__init__.py +4 -36
  27. {aot_biomaps-2.9.138.dist-info → aot_biomaps-2.9.279.dist-info}/METADATA +2 -1
  28. aot_biomaps-2.9.279.dist-info/RECORD +47 -0
  29. aot_biomaps-2.9.138.dist-info/RECORD +0 -43
  30. {aot_biomaps-2.9.138.dist-info → aot_biomaps-2.9.279.dist-info}/WHEEL +0 -0
  31. {aot_biomaps-2.9.138.dist-info → aot_biomaps-2.9.279.dist-info}/top_level.txt +0 -0
@@ -1,26 +1,86 @@
1
- from AOT_biomaps.AOT_Recon.ReconTools import _forward_projection, _backward_projection
1
+ from AOT_biomaps.AOT_Recon.ReconTools import _forward_projection, _backward_projection, check_gpu_memory, calculate_memory_requirement
2
2
  from AOT_biomaps.Config import config
3
-
3
+ from AOT_biomaps.AOT_Recon.AOT_SparseSMatrix.SparseSMatrix_SELL import SparseSMatrix_SELL
4
+ from AOT_biomaps.AOT_Recon.AOT_SparseSMatrix.SparseSMatrix_CSR import SparseSMatrix_CSR
5
+ from AOT_biomaps.AOT_Recon.ReconEnums import SMatrixType
4
6
  import numba
5
7
  import torch
6
8
  import numpy as np
7
9
  import os
8
10
  from tqdm import trange
9
-
10
-
11
- def _MLEM_GPU_basic(SMatrix, y, numIterations, isSavingEachIteration, withTumor):
11
+ import cupy as cp
12
+ import cupyx.scipy.sparse as cpsparse
13
+ import gc
14
+ import pycuda.driver as drv
15
+
16
+
17
+ def MLEM(
18
+ SMatrix,
19
+ y,
20
+ numIterations=100,
21
+ isSavingEachIteration=True,
22
+ withTumor=True,
23
+ device=None,
24
+ use_numba=False,
25
+ denominator_threshold=1e-6,
26
+ max_saves=5000,
27
+ show_logs=True,
28
+ smatrixType=SMatrixType.SELL,
29
+ Z=350,
30
+ ):
12
31
  """
13
- Optimized MLEM using GPU saves at most 1000 images, equally spaced.
14
- Returns both the images and their iteration indices.
32
+ Unified MLEM algorithm for Acousto-Optic Tomography.
33
+ Works on CPU (basic, multithread, optimized) and GPU (single or multi-GPU).
34
+ Args:
35
+ SMatrix: System matrix (shape: T, Z, X, N)
36
+ y: Measurement data (shape: T, N)
37
+ numIterations: Number of iterations
38
+ isSavingEachIteration: If True, saves intermediate results
39
+ withTumor: Boolean for description only
40
+ device: Torch device (auto-selected if None)
41
+ use_multi_gpu: If True and GPU available, uses all GPUs
42
+ use_numba: If True and on CPU, uses multithreaded Numba
43
+ max_saves: Maximum number of intermediate saves (default: 5000)
44
+ Returns:
45
+ Reconstructed image(s) and iteration indices (if isSavingEachIteration)
15
46
  """
47
+ # try:
48
+ tumor_str = "WITH" if withTumor else "WITHOUT"
49
+ # Auto-select device and method
50
+ if device is None:
51
+ if torch.cuda.is_available() and check_gpu_memory(config.select_best_gpu(), calculate_memory_requirement(SMatrix, y), show_logs=show_logs):
52
+ device = torch.device(f"cuda:{config.select_best_gpu()}")
53
+ use_gpu = True
54
+ else:
55
+ device = torch.device("cpu")
56
+ use_gpu = False
57
+ else:
58
+ use_gpu = device.type == "cuda"
59
+ # Dispatch to the appropriate implementation
60
+ if use_gpu:
61
+ if smatrixType == SMatrixType.CSR:
62
+ return MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, Z, show_logs)
63
+ elif smatrixType == SMatrixType.SELL:
64
+ return MLEM_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs)
65
+ elif smatrixType == SMatrixType.DENSE:
66
+ return _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold,show_logs)
67
+ else:
68
+ raise ValueError("Unsupported SMatrixType for GPU MLEM.")
69
+ else:
70
+ if use_numba:
71
+ return _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs)
72
+ else:
73
+ return _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs)
74
+ # except Exception as e:
75
+ # print(f"Error in MLEM: {type(e).__name__}: {e}")
76
+ # return None, None
77
+
78
+ def _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
16
79
  try:
17
- device = torch.device(f"cuda:{config.select_best_gpu()}")
18
80
  eps = torch.finfo(torch.float32).eps
19
81
  T, Z, X, N = SMatrix.shape
20
82
  ZX = Z * X
21
83
  TN = T * N
22
-
23
- # Conversion to flattened matrix
24
84
  A_flat = (
25
85
  torch.from_numpy(SMatrix)
26
86
  .to(device=device, dtype=torch.float32)
@@ -36,114 +96,78 @@ def _MLEM_GPU_basic(SMatrix, y, numIterations, isSavingEachIteration, withTumor)
36
96
  .sum(dim=(0, 3))
37
97
  .reshape(-1)
38
98
  )
39
-
40
- description = f"AOT-BioMaps -- ML-EM ---- {'WITH' if withTumor else 'WITHOUT'} TUMOR ---- GPU {torch.cuda.current_device()}"
99
+ description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- GPU {torch.cuda.current_device()}"
100
+ # Calculate save indices
101
+ if numIterations <= max_saves:
102
+ save_indices = list(range(numIterations))
103
+ else:
104
+ step = numIterations // max_saves
105
+ save_indices = list(range(0, numIterations, step))
106
+ if save_indices[-1] != numIterations - 1:
107
+ save_indices.append(numIterations - 1)
41
108
  saved_theta = []
42
109
  saved_indices = []
43
-
44
- if isSavingEachIteration:
45
- # Calculate step to save at most 1000 images
46
- step = max(1, numIterations // 1000)
47
- save_count = 0
48
-
49
110
  with torch.no_grad():
50
- for it in trange(numIterations, desc=description):
111
+ # Utilise range si show_logs=False, sinon trange
112
+ iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
113
+ for it in iterator:
51
114
  q_flat = A_flat @ theta_flat
52
- e_flat = y_flat / (q_flat + eps)
115
+ # Appliquer le seuil : si q_flat < denominator_threshold, on met e_flat à 1 (comme dans le code C++)
116
+ mask = q_flat >= denominator_threshold
117
+ e_flat = torch.where(mask, y_flat / (q_flat + eps), torch.ones_like(q_flat))
53
118
  c_flat = A_flat.T @ e_flat
54
119
  theta_flat = (theta_flat / (norm_factor_flat + eps)) * c_flat
55
-
56
- if isSavingEachIteration and (it % step == 0 or it == numIterations - 1):
120
+ if isSavingEachIteration and it in save_indices:
57
121
  saved_theta.append(theta_flat.reshape(Z, X).clone())
58
122
  saved_indices.append(it)
59
- save_count += 1
60
- if save_count >= 1000: # Hard limit
61
- break
62
-
63
123
  # Free memory
64
124
  del A_flat, y_flat, norm_factor_flat
65
125
  torch.cuda.empty_cache()
66
-
67
126
  if isSavingEachIteration:
68
127
  return [t.cpu().numpy() for t in saved_theta], saved_indices
69
128
  else:
70
129
  return theta_flat.reshape(Z, X).cpu().numpy(), None
71
-
72
130
  except Exception as e:
73
- print("Error in MLEM:", type(e).__name__, ":", e)
131
+ print(f"Error in single-GPU MLEM: {type(e).__name__}: {e}")
74
132
  torch.cuda.empty_cache()
75
- return None, []
76
-
77
- def _MLEM_CPU_basic(SMatrix, y, numIterations, isSavingEachIteration, withTumor):
78
- try:
79
- q_p = np.zeros((SMatrix.shape[0], SMatrix.shape[3]))
80
- c_p = np.zeros((SMatrix.shape[1], SMatrix.shape[2]))
81
- theta_p_0 = np.ones((SMatrix.shape[1], SMatrix.shape[2]))
82
- matrix_theta = [theta_p_0]
83
- saved_indices = [0]
84
- normalization_factor = np.sum(SMatrix, axis=(0, 3))
85
- description = f"AOT-BioMaps -- ML-EM ---- {'WITH' if withTumor else 'WITHOUT'} TUMOR ---- processing on single CPU (basic) ----"
86
-
87
- if isSavingEachIteration:
88
- step = max(1, (numIterations - 1) // 999)
89
- save_count = 1
90
-
91
- for it in trange(numIterations, desc=description):
92
- theta_p = matrix_theta[-1]
93
- for _t in range(SMatrix.shape[0]):
94
- for _n in range(SMatrix.shape[3]):
95
- q_p[_t, _n] = np.sum(SMatrix[_t, :, :, _n] * theta_p)
96
- e_p = y / (q_p + 1e-8)
97
- for _z in range(SMatrix.shape[1]):
98
- for _x in range(SMatrix.shape[2]):
99
- c_p[_z, _x] = np.sum(SMatrix[:, _z, _x, :] * e_p)
100
- theta_p_plus_1 = theta_p / (normalization_factor + 1e-8) * c_p
101
-
102
- if isSavingEachIteration and (it % step == 0 or it == numIterations - 1):
103
- matrix_theta.append(theta_p_plus_1)
104
- saved_indices.append(it + 1)
105
- save_count += 1
106
- if save_count >= 1000:
107
- break
108
- else:
109
- matrix_theta[-1] = theta_p_plus_1
110
-
111
- if not isSavingEachIteration:
112
- return matrix_theta[-1], None
113
- else:
114
- return matrix_theta, saved_indices
115
- except Exception as e:
116
- print("Error in basic CPU MLEM:", type(e).__name__, ":", e)
117
133
  return None, None
118
134
 
119
- def _MLEM_CPU_multi(SMatrix, y, numIterations, isSavingEachIteration, withTumor):
135
+ def _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs=True):
120
136
  try:
121
137
  numba.set_num_threads(os.cpu_count())
122
- q_p = np.zeros((SMatrix.shape[0], SMatrix.shape[3]))
123
- c_p = np.zeros((SMatrix.shape[1], SMatrix.shape[2]))
124
- theta_p_0 = np.ones((SMatrix.shape[1], SMatrix.shape[2]))
138
+ q_p = np.zeros((SMatrix.shape[0], SMatrix.shape[3]), dtype=np.float32)
139
+ c_p = np.zeros((SMatrix.shape[1], SMatrix.shape[2]), dtype=np.float32)
140
+ theta_p_0 = np.ones((SMatrix.shape[1], SMatrix.shape[2]), dtype=np.float32)
125
141
  matrix_theta = [theta_p_0]
126
142
  saved_indices = [0]
127
- normalization_factor = np.sum(SMatrix, axis=(0, 3))
128
- description = f"AOT-BioMaps -- ML-EM ---- {'WITH' if withTumor else 'WITHOUT'} TUMOR ---- processing on multithread CPU ({numba.config.NUMBA_DEFAULT_NUM_THREADS} threads)----"
143
+ normalization_factor = np.sum(SMatrix, axis=(0, 3)).astype(np.float32)
129
144
 
130
- if isSavingEachIteration:
131
- step = max(1, (numIterations - 1) // 999)
132
- save_count = 1
145
+ # Calculate save indices
146
+ if numIterations <= max_saves:
147
+ save_indices = list(range(numIterations))
148
+ else:
149
+ step = numIterations // max_saves
150
+ save_indices = list(range(0, numIterations, step))
151
+ if save_indices[-1] != numIterations - 1:
152
+ save_indices.append(numIterations - 1)
133
153
 
134
- for it in trange(numIterations, desc=description):
154
+ description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- processing on multithread CPU ({numba.config.NUMBA_DEFAULT_NUM_THREADS} threads) ----"
155
+ iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
156
+
157
+ for it in iterator:
135
158
  theta_p = matrix_theta[-1]
136
159
  _forward_projection(SMatrix, theta_p, q_p)
137
- e_p = y / (q_p + 1e-8)
160
+
161
+ # Appliquer le seuil : si q_p < denominator_threshold, on met e_p à 1
162
+ mask = q_p >= denominator_threshold
163
+ e_p = np.where(mask, y / (q_p + 1e-8), 1.0)
164
+
138
165
  _backward_projection(SMatrix, e_p, c_p)
139
166
  theta_p_plus_1 = theta_p / (normalization_factor + 1e-8) * c_p
140
167
 
141
- if isSavingEachIteration and (it % step == 0 or it == numIterations - 1):
168
+ if isSavingEachIteration and (it + 1) in save_indices:
142
169
  matrix_theta.append(theta_p_plus_1)
143
170
  saved_indices.append(it + 1)
144
- save_count += 1
145
- if save_count >= 1000:
146
- break
147
171
  else:
148
172
  matrix_theta[-1] = theta_p_plus_1
149
173
 
@@ -152,10 +176,10 @@ def _MLEM_CPU_multi(SMatrix, y, numIterations, isSavingEachIteration, withTumor)
152
176
  else:
153
177
  return matrix_theta, saved_indices
154
178
  except Exception as e:
155
- print("Error in multi-threaded CPU MLEM:", type(e).__name__, ":", e)
179
+ print(f"Error in Numba CPU MLEM: {type(e).__name__}: {e}")
156
180
  return None, None
157
-
158
- def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, withTumor):
181
+
182
+ def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs=True):
159
183
  try:
160
184
  T, Z, X, N = SMatrix.shape
161
185
  A_flat = SMatrix.astype(np.float32).transpose(0, 3, 1, 2).reshape(T * N, Z * X)
@@ -165,27 +189,35 @@ def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, withTumor):
165
189
  saved_indices = [0]
166
190
  normalization_factor = np.sum(SMatrix, axis=(0, 3)).astype(np.float32)
167
191
  normalization_factor_flat = normalization_factor.reshape(-1)
168
- description = f"AOT-BioMaps -- ML-EM ---- {'WITH' if withTumor else 'WITHOUT'} TUMOR ---- processing on single CPU (optimized) ----"
169
192
 
170
- if isSavingEachIteration:
171
- step = max(1, (numIterations - 1) // 999)
172
- save_count = 1
193
+ # Calculate save indices
194
+ if numIterations <= max_saves:
195
+ save_indices = list(range(numIterations))
196
+ else:
197
+ step = numIterations // max_saves
198
+ save_indices = list(range(0, numIterations, step))
199
+ if save_indices[-1] != numIterations - 1:
200
+ save_indices.append(numIterations - 1)
201
+
202
+ description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- processing on single CPU (optimized) ----"
203
+ iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
173
204
 
174
- for it in trange(numIterations, desc=description):
205
+ for it in iterator:
175
206
  theta_p = matrix_theta[-1]
176
207
  theta_p_flat = theta_p.reshape(-1)
177
208
  q_flat = A_flat @ theta_p_flat
178
- e_flat = y_flat / (q_flat + np.finfo(np.float32).tiny)
209
+
210
+ # Appliquer le seuil : si q_flat < denominator_threshold, on met e_flat à 1
211
+ mask = q_flat >= denominator_threshold
212
+ e_flat = np.where(mask, y_flat / (q_flat + np.finfo(np.float32).tiny), 1.0)
213
+
179
214
  c_flat = A_flat.T @ e_flat
180
215
  theta_p_plus_1_flat = theta_p_flat / (normalization_factor_flat + np.finfo(np.float32).tiny) * c_flat
181
216
  theta_p_plus_1 = theta_p_plus_1_flat.reshape(Z, X)
182
217
 
183
- if isSavingEachIteration and (it % step == 0 or it == numIterations - 1):
218
+ if isSavingEachIteration and (it + 1) in save_indices:
184
219
  matrix_theta.append(theta_p_plus_1)
185
220
  saved_indices.append(it + 1)
186
- save_count += 1
187
- if save_count >= 1000:
188
- break
189
221
  else:
190
222
  matrix_theta[-1] = theta_p_plus_1
191
223
 
@@ -194,67 +226,241 @@ def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, withTumor):
194
226
  else:
195
227
  return matrix_theta, saved_indices
196
228
  except Exception as e:
197
- print("Error in optimized CPU MLEM:", type(e).__name__, ":", e)
229
+ print(f"Error in optimized CPU MLEM: {type(e).__name__}: {e}")
198
230
  return None, None
231
+
232
+ def MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
233
+ """
234
+ SMatrix: instance of SparseMatrixGPU (already allocated)
235
+ y: measured data (1D np.float32 of length TN)
199
236
 
200
- def _MLEM_GPU_multi(SMatrix, y, numIterations, isSavingEachIteration, withTumor):
237
+ Assumptions:
238
+ - SMatrix.values_gpu and SMatrix.col_ind_gpu and SMatrix.row_ptr_gpu are device pointers
239
+ - SMatrix.norm_factor_inv_gpu exists
240
+ - SMatrix.ctx is the PyCUDA context for the target GPU.
241
+ """
242
+
243
+ # We use a final_result placeholder to ensure it's defined outside the try block
244
+ final_result = None
245
+
201
246
  try:
202
- num_gpus = torch.cuda.device_count()
203
- device = torch.device('cuda:0')
204
- T, Z, X, N = SMatrix.shape
205
- A_matrix_torch = torch.tensor(SMatrix, dtype=torch.float32).to(device).permute(0, 3, 1, 2).reshape(T * N, Z * X)
206
- y_torch = torch.tensor(y, dtype=torch.float32).to(device).reshape(-1)
207
- A_split = torch.chunk(A_matrix_torch, num_gpus, dim=0)
208
- y_split = torch.chunk(y_torch, num_gpus)
209
- theta_0 = torch.ones((Z, X), dtype=torch.float32, device=device)
210
- theta_list = [theta_0.clone().to(device) for _ in range(num_gpus)]
211
- normalization_factor = A_matrix_torch.sum(dim=0).reshape(Z, X).to(device)
212
- saved_theta = [theta_0.cpu().numpy()]
213
- saved_indices = [0]
214
- description = f"AOT-BioMaps -- ML-EM ---- {'WITH' if withTumor else 'WITHOUT'} TUMOR ---- processing on multi-GPU ----"
247
+ if not isinstance(SMatrix, SparseSMatrix_CSR):
248
+ raise TypeError("SMatrix must be a SparseSMatrix_CSR object")
249
+
250
+ # --- CONTEXT FIX: Push the context associated with SMatrix ---
251
+ # This ensures all subsequent PyCUDA operations use the correct GPU/context.
252
+ if SMatrix.ctx:
253
+ SMatrix.ctx.push()
254
+ # -----------------------------------------------------------
255
+
256
+ dtype = np.float32
257
+ TN = SMatrix.N * SMatrix.T
258
+ ZX = SMatrix.Z * SMatrix.X
259
+ # Ensure Z and X are correctly defined for reshaping
260
+ Z = SMatrix.Z
261
+ X = SMatrix.X
262
+
263
+ if show_logs:
264
+ # We assume SMatrix was initialized using the correct device index.
265
+ print(f"Executing on GPU device index: {SMatrix.device.primary_context.device.name()}")
266
+ print(f"Dim X: {X}, Dim Z: {Z}, TN: {TN}, ZX: {ZX}")
267
+
268
+ # streams
269
+ stream = drv.Stream()
270
+
271
+ # allocate device buffers
272
+ y = y.T.flatten().astype(np.float32)
273
+ y_gpu = drv.mem_alloc(y.nbytes)
274
+ drv.memcpy_htod_async(y_gpu, y.astype(dtype), stream)
275
+
276
+ theta_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
277
+ initial_theta = np.full(ZX, 0.1, dtype=dtype)
278
+ drv.memcpy_htod_async(theta_flat_gpu, initial_theta, stream)
279
+
280
+ norm_factor_inv_gpu = SMatrix.norm_factor_inv_gpu
281
+
282
+ q_flat_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
283
+ e_flat_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
284
+ c_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
285
+
286
+ # Assuming the cubin file is found globally or managed by the caller
287
+ mod = drv.module_from_file('AOT_biomaps_kernels.cubin')
288
+ projection_kernel = mod.get_function('projection_kernel__CSR')
289
+ backprojection_kernel = mod.get_function('backprojection_kernel__CSR')
290
+ ratio_kernel = mod.get_function('ratio_kernel')
291
+ update_kernel = mod.get_function('update_theta_kernel')
292
+ block_size = 256
293
+
294
+ saved_theta, saved_indices = [], []
295
+ if numIterations <= max_saves:
296
+ save_indices = list(range(numIterations))
297
+ else:
298
+ save_indices = list(range(0, numIterations, max(1, numIterations // max_saves)))
299
+ if save_indices[-1] != numIterations - 1:
300
+ save_indices.append(numIterations - 1)
215
301
 
216
- if isSavingEachIteration:
217
- step = max(1, (numIterations - 1) // 999)
218
- save_count = 1
219
-
220
- for it in trange(numIterations, desc=description):
221
- theta_p_list = []
222
- for i in range(num_gpus):
223
- with torch.cuda.device(f'cuda:{i}'):
224
- theta_p = theta_list[i].to(f'cuda:{i}')
225
- A_i = A_split[i].to(f'cuda:{i}')
226
- y_i = y_split[i].to(f'cuda:{i}')
227
- q_flat = A_i @ theta_p.reshape(-1)
228
- e_flat = y_i / (q_flat + torch.finfo(torch.float32).tiny)
229
- c_flat = A_i.T @ e_flat
230
- theta_p_plus_1_flat = (theta_p.reshape(-1) / (normalization_factor.to(f'cuda:{i}').reshape(-1) + torch.finfo(torch.float32).tiny)) * c_flat
231
- theta_p_plus_1 = theta_p_plus_1_flat.reshape(Z, X)
232
- theta_p_list.append(theta_p_plus_1)
233
-
234
- for i in range(num_gpus):
235
- theta_list[i] = theta_p_list[i].to('cuda:0')
236
-
237
- if isSavingEachIteration and (it % step == 0 or it == numIterations - 1):
238
- saved_theta.append(torch.stack(theta_p_list).mean(dim=0).cpu().numpy())
239
- saved_indices.append(it + 1)
240
- save_count += 1
241
- if save_count >= 1000:
242
- break
302
+ description = f"AOT-BioMaps -- ML-EM (CSR-sparse SMatrix) ---- {tumor_str} TUMOR ---- GPU {torch.cuda.current_device()}"
303
+ iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
304
+ for it in iterator:
305
+ # projection: q = A * theta
306
+ projection_kernel(q_flat_gpu, SMatrix.values_gpu, SMatrix.row_ptr_gpu, SMatrix.col_ind_gpu,
307
+ theta_flat_gpu, np.int32(TN),
308
+ block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1),
309
+ stream=stream)
243
310
 
244
- del A_matrix_torch, y_torch, A_split, y_split, theta_0, normalization_factor
245
- torch.cuda.empty_cache()
246
- for i in range(num_gpus):
247
- torch.cuda.empty_cache()
311
+ # ratio: e = y / max(q, threshold)
312
+ ratio_kernel(e_flat_gpu, y_gpu, q_flat_gpu, np.float32(denominator_threshold), np.int32(TN),
313
+ block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
314
+
315
+ # backprojection: c = A^T * e
316
+ drv.memset_d32_async(c_flat_gpu, 0, ZX, stream)
317
+ backprojection_kernel(c_flat_gpu, SMatrix.values_gpu, SMatrix.row_ptr_gpu, SMatrix.col_ind_gpu,
318
+ e_flat_gpu, np.int32(TN),
319
+ block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
320
+
321
+ # update: theta *= norm_factor_inv * c
322
+ update_kernel(theta_flat_gpu, c_flat_gpu, norm_factor_inv_gpu, np.int32(ZX),
323
+ block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
324
+
325
+ if show_logs and (it % 10 == 0 or it == numIterations - 1):
326
+ drv.Context.synchronize()
327
+
328
+ if isSavingEachIteration and it in save_indices:
329
+ theta_host = np.empty(ZX, dtype=dtype)
330
+ drv.memcpy_dtoh(theta_host, theta_flat_gpu)
331
+ saved_theta.append(theta_host.reshape(Z, X))
332
+ saved_indices.append(it)
333
+
334
+ drv.Context.synchronize()
335
+
336
+ final_result = np.empty(ZX, dtype=dtype)
337
+ drv.memcpy_dtoh(final_result, theta_flat_gpu)
338
+ final_result = final_result.reshape(Z, X)
339
+
340
+ # free local allocations
341
+ y_gpu.free(); q_flat_gpu.free(); e_flat_gpu.free(); c_flat_gpu.free(); theta_flat_gpu.free()
342
+
343
+ return (saved_theta, saved_indices) if isSavingEachIteration else (final_result, None)
248
344
 
249
- if not isSavingEachIteration:
250
- return torch.stack(theta_p_list).mean(dim=0).cpu().numpy(), None
251
- else:
252
- return saved_theta, saved_indices
253
345
  except Exception as e:
254
- print("Error in multi-GPU MLEM:", type(e).__name__, ":", e)
255
- del A_matrix_torch, y_torch, A_split, y_split, theta_0, normalization_factor
256
- torch.cuda.empty_cache()
257
- for i in range(num_gpus):
258
- torch.cuda.empty_cache()
346
+ print(f"Error in MLEM_sparseCSR_pycuda: {type(e).__name__}: {e}")
347
+ gc.collect()
259
348
  return None, None
349
+
350
+ finally:
351
+ # --- CONTEXT FIX: Pop the context ---
352
+ if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
353
+ SMatrix.ctx.pop()
354
+ # ------------------------------------
355
+
356
+ def MLEM_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
357
+ """
358
+ MLEM using SELL-C-σ kernels already present on device.
359
+ y must be float32 length TN.
360
+ """
361
+ final_result = None
260
362
 
363
+ try:
364
+ # check if SMatrix is SparseSMatrix_SELL object
365
+ if not isinstance(SMatrix, SparseSMatrix_SELL):
366
+ raise TypeError("SMatrix must be a SparseSMatrix_SELL object")
367
+ if SMatrix.sell_values_gpu is None:
368
+ raise RuntimeError("SELL not built. Call allocate_sell_c_sigma_direct() first.")
369
+
370
+ # --- CONTEXT FIX: Push the context associated with SMatrix ---
371
+ # This ensures all subsequent PyCUDA operations use the correct GPU/context.
372
+ if SMatrix.ctx:
373
+ SMatrix.ctx.push()
374
+ # -----------------------------------------------------------
375
+
376
+ TN = int(SMatrix.N * SMatrix.T)
377
+ ZX = int(SMatrix.Z * SMatrix.X)
378
+ dtype = np.float32
379
+ block_size = 256
380
+
381
+ mod = drv.module_from_file('AOT_biomaps_kernels.cubin')
382
+ proj = mod.get_function("projection_kernel__SELL")
383
+ backproj = mod.get_function("backprojection_kernel__SELL")
384
+ ratio = mod.get_function("ratio_kernel")
385
+ update = mod.get_function("update_theta_kernel")
386
+
387
+ stream = drv.Stream()
388
+
389
+ # device buffers
390
+ y = y.T.flatten().astype(np.float32)
391
+ y_gpu = drv.mem_alloc(y.nbytes)
392
+ drv.memcpy_htod_async(y_gpu, y.astype(dtype), stream)
393
+
394
+ theta_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
395
+ drv.memcpy_htod_async(theta_gpu, np.full(ZX, 0.1, dtype=dtype), stream)
396
+
397
+ q_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
398
+ e_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
399
+ c_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
400
+
401
+ slice_ptr_gpu = SMatrix.slice_ptr_gpu
402
+ slice_len_gpu = SMatrix.slice_len_gpu
403
+ slice_height = np.int32(SMatrix.slice_height)
404
+
405
+ grid_rows = ((TN + block_size - 1) // block_size, 1, 1)
406
+ grid_cols = ((ZX + block_size - 1) // block_size, 1, 1)
407
+
408
+ saved_theta, saved_indices = [], []
409
+ if numIterations <= max_saves:
410
+ save_indices = list(range(numIterations))
411
+ else:
412
+ save_indices = list(range(0, numIterations, max(1, numIterations // max_saves)))
413
+ if save_indices[-1] != numIterations - 1:
414
+ save_indices.append(numIterations - 1)
415
+
416
+ description = f"AOT-BioMaps -- ML-EM (SELL-c-σ-sparse SMatrix) ---- {tumor_str} TUMOR ---- GPU {torch.cuda.current_device()}"
417
+ iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
418
+ for it in iterator:
419
+ # projection
420
+ proj(q_gpu, SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, slice_ptr_gpu, slice_len_gpu,
421
+ theta_gpu, np.int32(TN), slice_height,
422
+ block=(block_size,1,1), grid=grid_rows, stream=stream)
423
+
424
+ # ratio
425
+ ratio(e_gpu, y_gpu, q_gpu, np.float32(denominator_threshold), np.int32(TN),
426
+ block=(block_size,1,1), grid=grid_rows, stream=stream)
427
+
428
+ # zero c
429
+ drv.memset_d32_async(c_gpu, 0, ZX, stream)
430
+
431
+ # backprojection accumulate
432
+ backproj(SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, slice_ptr_gpu, slice_len_gpu,
433
+ e_gpu, c_gpu, np.int32(TN), slice_height,
434
+ block=(block_size,1,1), grid=grid_rows, stream=stream)
435
+
436
+ # update
437
+ update(theta_gpu, c_gpu, SMatrix.norm_factor_inv_gpu, np.int32(ZX),
438
+ block=(block_size,1,1), grid=grid_cols, stream=stream)
439
+
440
+ stream.synchronize()
441
+ if isSavingEachIteration and it in save_indices:
442
+ out = np.empty(ZX, dtype=np.float32)
443
+ drv.memcpy_dtoh(out, theta_gpu)
444
+ saved_theta.append(out.reshape((SMatrix.Z, SMatrix.X)))
445
+ saved_indices.append(it)
446
+
447
+ # final copy
448
+ res = np.empty(ZX, dtype=np.float32)
449
+ drv.memcpy_dtoh(res, theta_gpu)
450
+
451
+ # free temporaries
452
+ y_gpu.free(); q_gpu.free(); e_gpu.free(); c_gpu.free(); theta_gpu.free()
453
+
454
+ final_result = res.reshape((SMatrix.Z, SMatrix.X))
455
+ return (saved_theta, saved_indices) if isSavingEachIteration else (final_result, None)
456
+
457
+ except Exception as e:
458
+ print(f"Error in MLEM_sparseSELL_pycuda: {type(e).__name__}: {e}")
459
+ gc.collect()
460
+ return None, None
461
+
462
+ finally:
463
+ # --- CONTEXT FIX: Pop the context ---
464
+ if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
465
+ SMatrix.ctx.pop()
466
+ # ------------------------------------