AOT-biomaps 2.9.186__py3-none-any.whl → 2.9.294__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of AOT-biomaps might be problematic.

Files changed (28)
  1. AOT_biomaps/AOT_Acoustic/StructuredWave.py +2 -2
  2. AOT_biomaps/AOT_Acoustic/_mainAcoustic.py +11 -6
  3. AOT_biomaps/AOT_Experiment/Tomography.py +74 -4
  4. AOT_biomaps/AOT_Experiment/_mainExperiment.py +95 -55
  5. AOT_biomaps/AOT_Recon/AOT_Optimizers/DEPIERRO.py +48 -13
  6. AOT_biomaps/AOT_Recon/AOT_Optimizers/LS.py +406 -13
  7. AOT_biomaps/AOT_Recon/AOT_Optimizers/MAPEM.py +118 -38
  8. AOT_biomaps/AOT_Recon/AOT_Optimizers/MLEM.py +303 -102
  9. AOT_biomaps/AOT_Recon/AOT_Optimizers/PDHG.py +443 -12
  10. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/RelativeDifferences.py +10 -14
  11. AOT_biomaps/AOT_Recon/AOT_SparseSMatrix/SparseSMatrix_CSR.py +274 -0
  12. AOT_biomaps/AOT_Recon/AOT_SparseSMatrix/SparseSMatrix_SELL.py +328 -0
  13. AOT_biomaps/AOT_Recon/AOT_SparseSMatrix/__init__.py +2 -0
  14. AOT_biomaps/AOT_Recon/AOT_biomaps_kernels.cubin +0 -0
  15. AOT_biomaps/AOT_Recon/AlgebraicRecon.py +243 -113
  16. AOT_biomaps/AOT_Recon/AnalyticRecon.py +26 -41
  17. AOT_biomaps/AOT_Recon/BayesianRecon.py +81 -146
  18. AOT_biomaps/AOT_Recon/PrimalDualRecon.py +157 -94
  19. AOT_biomaps/AOT_Recon/ReconEnums.py +27 -2
  20. AOT_biomaps/AOT_Recon/ReconTools.py +229 -12
  21. AOT_biomaps/AOT_Recon/__init__.py +1 -0
  22. AOT_biomaps/AOT_Recon/_mainRecon.py +60 -53
  23. AOT_biomaps/__init__.py +4 -69
  24. {aot_biomaps-2.9.186.dist-info → aot_biomaps-2.9.294.dist-info}/METADATA +2 -1
  25. aot_biomaps-2.9.294.dist-info/RECORD +47 -0
  26. aot_biomaps-2.9.186.dist-info/RECORD +0 -43
  27. {aot_biomaps-2.9.186.dist-info → aot_biomaps-2.9.294.dist-info}/WHEEL +0 -0
  28. {aot_biomaps-2.9.186.dist-info → aot_biomaps-2.9.294.dist-info}/top_level.txt +0 -0
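The diff below covers AOT_biomaps/AOT_Recon/AOT_Optimizers/MLEM.py, the largest change of the release: the use_multi_gpu flag (and the _MLEM_multi_GPU implementation) is removed, and MLEM now dispatches on a smatrixType enum (CSR, SELL, or DENSE), with a new denominator_threshold guard and a show_logs switch. A minimal, hypothetical usage sketch follows; the module path is taken from the file list above, and the leading (SMatrix, y, numIterations) arguments are inferred from the internal calls in the diff rather than from a documented API.

# Hypothetical usage sketch; parameter names are read from the diff below.
import numpy as np
from AOT_biomaps.AOT_Recon.AOT_Optimizers.MLEM import MLEM
from AOT_biomaps.AOT_Recon.ReconEnums import SMatrixType

T, Z, X, N = 64, 32, 32, 16
SMatrix = np.random.rand(T, Z, X, N).astype(np.float32)  # dense system matrix (T, Z, X, N)
y = np.random.rand(T, N).astype(np.float32)              # measured data

# DENSE keeps the torch-based path; CSR/SELL expect a pre-built
# SparseSMatrix_CSR / SparseSMatrix_SELL object instead of a NumPy array.
theta, saved_indices = MLEM(
    SMatrix,
    y,
    numIterations=200,
    isSavingEachIteration=False,
    smatrixType=SMatrixType.DENSE,
    denominator_threshold=1e-6,
    show_logs=False,
)

The CSR and SELL containers themselves are new in this release (SparseSMatrix_CSR.py and SparseSMatrix_SELL.py in the file list).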
AOT_biomaps/AOT_Recon/AOT_Optimizers/MLEM.py

@@ -1,10 +1,18 @@
  from AOT_biomaps.AOT_Recon.ReconTools import _forward_projection, _backward_projection, check_gpu_memory, calculate_memory_requirement
  from AOT_biomaps.Config import config
+ from AOT_biomaps.AOT_Recon.AOT_SparseSMatrix.SparseSMatrix_SELL import SparseSMatrix_SELL
+ from AOT_biomaps.AOT_Recon.AOT_SparseSMatrix.SparseSMatrix_CSR import SparseSMatrix_CSR
+ from AOT_biomaps.AOT_Recon.ReconEnums import SMatrixType
  import numba
  import torch
  import numpy as np
  import os
  from tqdm import trange
+ import cupy as cp
+ import cupyx.scipy.sparse as cpsparse
+ import gc
+ import pycuda.driver as drv
+

  def MLEM(
      SMatrix,
@@ -13,9 +21,11 @@ def MLEM(
      isSavingEachIteration=True,
      withTumor=True,
      device=None,
-     use_multi_gpu=False,
      use_numba=False,
+     denominator_threshold=1e-6,
      max_saves=5000,
+     show_logs=True,
+     smatrixType=SMatrixType.SELL,
  ):
      """
      Unified MLEM algorithm for Acousto-Optic Tomography.
@@ -33,34 +43,38 @@ def MLEM(
      Returns:
          Reconstructed image(s) and iteration indices (if isSavingEachIteration)
      """
-     try:
-         tumor_str = "WITH" if withTumor else "WITHOUT"
-         # Auto-select device and method
-         if device is None:
-             if torch.cuda.is_available() and check_gpu_memory(config.select_best_gpu(), calculate_memory_requirement(SMatrix, y)):
-                 device = torch.device(f"cuda:{config.select_best_gpu()}")
-                 use_gpu = True
-             else:
-                 device = torch.device("cpu")
-                 use_gpu = False
+     # try:
+     tumor_str = "WITH" if withTumor else "WITHOUT"
+     # Auto-select device and method
+     if device is None:
+         if torch.cuda.is_available() and check_gpu_memory(config.select_best_gpu(), calculate_memory_requirement(SMatrix, y), show_logs=show_logs):
+             device = torch.device(f"cuda:{config.select_best_gpu()}")
+             use_gpu = True
          else:
-             use_gpu = device.type == "cuda"
-         # Dispatch to the appropriate implementation
-         if use_gpu:
-             if use_multi_gpu and torch.cuda.device_count() > 1:
-                 return _MLEM_multi_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves)
+             device = torch.device("cpu")
+             use_gpu = False
+     else:
+         use_gpu = device.type == "cuda"
+     # Dispatch to the appropriate implementation
+     if use_gpu:
+         if smatrixType == SMatrixType.CSR:
+             return MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs)
+         elif smatrixType == SMatrixType.SELL:
+             return MLEM_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs)
+         elif smatrixType == SMatrixType.DENSE:
+             return _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold,show_logs)
          else:
-             return _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves)
+             raise ValueError("Unsupported SMatrixType for GPU MLEM.")
+     else:
+         if use_numba:
+             return _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs)
          else:
-             if use_numba:
-                 return _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves)
-             else:
-                 return _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves)
-     except Exception as e:
-         print(f"Error in MLEM: {type(e).__name__}: {e}")
-         return None, None
+             return _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs)
+     # except Exception as e:
+     #     print(f"Error in MLEM: {type(e).__name__}: {e}")
+     #     return None, None

- def _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves=5000):
+ def _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
      try:
          eps = torch.finfo(torch.float32).eps
          T, Z, X, N = SMatrix.shape
@@ -82,7 +96,6 @@ def _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str
              .reshape(-1)
          )
          description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- GPU {torch.cuda.current_device()}"
-
          # Calculate save indices
          if numIterations <= max_saves:
              save_indices = list(range(numIterations))
@@ -91,20 +104,21 @@ def _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str
              save_indices = list(range(0, numIterations, step))
              if save_indices[-1] != numIterations - 1:
                  save_indices.append(numIterations - 1)
-
          saved_theta = []
          saved_indices = []
-
          with torch.no_grad():
-             for it in trange(numIterations, desc=description):
+             # Use range if show_logs=False, otherwise trange
+             iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+             for it in iterator:
                  q_flat = A_flat @ theta_flat
-                 e_flat = y_flat / (q_flat + eps)
+                 # Apply the threshold: if q_flat < denominator_threshold, set e_flat to 1 (as in the C++ code)
+                 mask = q_flat >= denominator_threshold
+                 e_flat = torch.where(mask, y_flat / (q_flat + eps), torch.ones_like(q_flat))
                  c_flat = A_flat.T @ e_flat
                  theta_flat = (theta_flat / (norm_factor_flat + eps)) * c_flat
                  if isSavingEachIteration and it in save_indices:
                      saved_theta.append(theta_flat.reshape(Z, X).clone())
                      saved_indices.append(it)
-
          # Free memory
          del A_flat, y_flat, norm_factor_flat
          torch.cuda.empty_cache()
@@ -117,74 +131,15 @@ def _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str
          torch.cuda.empty_cache()
          return None, None

- def _MLEM_multi_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves=5000):
-     try:
-         num_gpus = torch.cuda.device_count()
-         device = torch.device('cuda:0')
-         T, Z, X, N = SMatrix.shape
-         A_matrix_torch = torch.tensor(SMatrix, dtype=torch.float32).to(device).permute(0, 3, 1, 2).reshape(T * N, Z * X)
-         y_torch = torch.tensor(y, dtype=torch.float32).to(device).reshape(-1)
-         A_split = torch.chunk(A_matrix_torch, num_gpus, dim=0)
-         y_split = torch.chunk(y_torch, num_gpus)
-         theta_0 = torch.ones((Z, X), dtype=torch.float32, device=device)
-         theta_list = [theta_0.clone().to(device) for _ in range(num_gpus)]
-         normalization_factor = A_matrix_torch.sum(dim=0).reshape(Z, X).to(device)
-
-         # Calculate save indices
-         if numIterations <= max_saves:
-             save_indices = list(range(numIterations))
-         else:
-             step = numIterations // max_saves
-             save_indices = list(range(0, numIterations, step))
-             if save_indices[-1] != numIterations - 1:
-                 save_indices.append(numIterations - 1)
-
-         saved_theta = [theta_0.cpu().numpy()]
-         saved_indices = [0]
-         description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- processing on multi-GPU ({num_gpus} GPUs) ----"
-
-         for it in trange(numIterations, desc=description):
-             theta_p_list = []
-             for i in range(num_gpus):
-                 with torch.cuda.device(f'cuda:{i}'):
-                     theta_p = theta_list[i].to(f'cuda:{i}')
-                     A_i = A_split[i].to(f'cuda:{i}')
-                     y_i = y_split[i].to(f'cuda:{i}')
-                     q_flat = A_i @ theta_p.reshape(-1)
-                     e_flat = y_i / (q_flat + torch.finfo(torch.float32).tiny)
-                     c_flat = A_i.T @ e_flat
-                     theta_p_plus_1_flat = (theta_p.reshape(-1) / (normalization_factor.to(f'cuda:{i}').reshape(-1) + torch.finfo(torch.float32).tiny)) * c_flat
-                     theta_p_plus_1 = theta_p_plus_1_flat.reshape(Z, X)
-                     theta_p_list.append(theta_p_plus_1)
-             for i in range(num_gpus):
-                 theta_list[i] = theta_p_list[i].to('cuda:0')
-             if isSavingEachIteration and it in save_indices:
-                 saved_theta.append(torch.stack(theta_p_list).mean(dim=0).cpu().numpy())
-                 saved_indices.append(it + 1)
-
-         del A_matrix_torch, y_torch, A_split, y_split, theta_0, normalization_factor
-         for i in range(num_gpus):
-             torch.cuda.empty_cache()
-         if not isSavingEachIteration:
-             return torch.stack(theta_p_list).mean(dim=0).cpu().numpy(), None
-         else:
-             return saved_theta, saved_indices
-     except Exception as e:
-         print(f"Error in multi-GPU MLEM: {type(e).__name__}: {e}")
-         del A_matrix_torch, y_torch, A_split, y_split, theta_0, normalization_factor
-         for i in range(num_gpus):
-             torch.cuda.empty_cache()
-         return None, None
-
- def _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves=5000):
+ def _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs=True):
      try:
          numba.set_num_threads(os.cpu_count())
-         q_p = np.zeros((SMatrix.shape[0], SMatrix.shape[3]))
-         c_p = np.zeros((SMatrix.shape[1], SMatrix.shape[2]))
-         theta_p_0 = np.ones((SMatrix.shape[1], SMatrix.shape[2]))
+         q_p = np.zeros((SMatrix.shape[0], SMatrix.shape[3]), dtype=np.float32)
+         c_p = np.zeros((SMatrix.shape[1], SMatrix.shape[2]), dtype=np.float32)
+         theta_p_0 = np.ones((SMatrix.shape[1], SMatrix.shape[2]), dtype=np.float32)
          matrix_theta = [theta_p_0]
          saved_indices = [0]
-         normalization_factor = np.sum(SMatrix, axis=(0, 3))
+         normalization_factor = np.sum(SMatrix, axis=(0, 3)).astype(np.float32)

          # Calculate save indices
          if numIterations <= max_saves:
@@ -196,14 +151,20 @@ def _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str,
                  save_indices.append(numIterations - 1)

          description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- processing on multithread CPU ({numba.config.NUMBA_DEFAULT_NUM_THREADS} threads) ----"
+         iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)

-         for it in trange(numIterations, desc=description):
+         for it in iterator:
              theta_p = matrix_theta[-1]
              _forward_projection(SMatrix, theta_p, q_p)
-             e_p = y / (q_p + 1e-8)
+
+             # Apply the threshold: if q_p < denominator_threshold, set e_p to 1
+             mask = q_p >= denominator_threshold
+             e_p = np.where(mask, y / (q_p + 1e-8), 1.0)
+
              _backward_projection(SMatrix, e_p, c_p)
              theta_p_plus_1 = theta_p / (normalization_factor + 1e-8) * c_p
-             if isSavingEachIteration and it in save_indices:
+
+             if isSavingEachIteration and (it + 1) in save_indices:
                  matrix_theta.append(theta_p_plus_1)
                  saved_indices.append(it + 1)
              else:
@@ -217,7 +178,7 @@ def _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str,
          print(f"Error in Numba CPU MLEM: {type(e).__name__}: {e}")
          return None, None

- def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves=5000):
+ def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs=True):
      try:
          T, Z, X, N = SMatrix.shape
          A_flat = SMatrix.astype(np.float32).transpose(0, 3, 1, 2).reshape(T * N, Z * X)
@@ -238,16 +199,22 @@ def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str,
              save_indices.append(numIterations - 1)

          description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- processing on single CPU (optimized) ----"
+         iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)

-         for it in trange(numIterations, desc=description):
+         for it in iterator:
              theta_p = matrix_theta[-1]
              theta_p_flat = theta_p.reshape(-1)
              q_flat = A_flat @ theta_p_flat
-             e_flat = y_flat / (q_flat + np.finfo(np.float32).tiny)
+
+             # Apply the threshold: if q_flat < denominator_threshold, set e_flat to 1
+             mask = q_flat >= denominator_threshold
+             e_flat = np.where(mask, y_flat / (q_flat + np.finfo(np.float32).tiny), 1.0)
+
              c_flat = A_flat.T @ e_flat
              theta_p_plus_1_flat = theta_p_flat / (normalization_factor_flat + np.finfo(np.float32).tiny) * c_flat
              theta_p_plus_1 = theta_p_plus_1_flat.reshape(Z, X)
-             if isSavingEachIteration and it in save_indices:
+
+             if isSavingEachIteration and (it + 1) in save_indices:
                  matrix_theta.append(theta_p_plus_1)
                  saved_indices.append(it + 1)
              else:
@@ -260,3 +227,237 @@ def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str,
      except Exception as e:
          print(f"Error in optimized CPU MLEM: {type(e).__name__}: {e}")
          return None, None
+
+ def MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
+     """
+     SMatrix: instance of SparseMatrixGPU (already allocated)
+     y: measured data (1D np.float32 of length TN)
+
+     Assumptions:
+     - SMatrix.values_gpu and SMatrix.col_ind_gpu and SMatrix.row_ptr_gpu are device pointers
+     - SMatrix.norm_factor_inv_gpu exists
+     - SMatrix.ctx is the PyCUDA context for the target GPU.
+     """
+
+     # We use a final_result placeholder to ensure it's defined outside the try block
+     final_result = None
+
+     try:
+         if not isinstance(SMatrix, SparseSMatrix_CSR):
+             raise TypeError("SMatrix must be a SparseSMatrix_CSR object")
+
+         # --- CONTEXT FIX: Push the context associated with SMatrix ---
+         # This ensures all subsequent PyCUDA operations use the correct GPU/context.
+         if SMatrix.ctx:
+             SMatrix.ctx.push()
+         # -----------------------------------------------------------
+
+         dtype = np.float32
+         TN = SMatrix.N * SMatrix.T
+         ZX = SMatrix.Z * SMatrix.X
+         # Ensure Z and X are correctly defined for reshaping
+         Z = SMatrix.Z
+         X = SMatrix.X
+
+         if show_logs:
+             # We assume SMatrix was initialized using the correct device index.
+             print(f"Executing on GPU device index: {SMatrix.device.primary_context.device.name()}")
+             print(f"Dim X: {X}, Dim Z: {Z}, TN: {TN}, ZX: {ZX}")
+
+         # streams
+         stream = drv.Stream()
+
+         # allocate device buffers
+         y = y.T.flatten().astype(np.float32)
+         y_gpu = drv.mem_alloc(y.nbytes)
+         drv.memcpy_htod_async(y_gpu, y.astype(dtype), stream)
+
+         theta_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
+         initial_theta = np.full(ZX, 0.1, dtype=dtype)
+         drv.memcpy_htod_async(theta_flat_gpu, initial_theta, stream)
+
+         norm_factor_inv_gpu = SMatrix.norm_factor_inv_gpu
+
+         q_flat_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
+         e_flat_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
+         c_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
+
+         # Assuming the cubin file is found globally or managed by the caller
+         projection_kernel = SMatrix.sparse_mod.get_function('projection_kernel__CSR')
+         backprojection_kernel = SMatrix.sparse_mod.get_function('backprojection_kernel__CSR')
+         ratio_kernel = SMatrix.sparse_mod.get_function('ratio_kernel')
+         update_kernel = SMatrix.sparse_mod.get_function('update_theta_kernel')
+         block_size = 256
+
+         saved_theta, saved_indices = [], []
+         if numIterations <= max_saves:
+             save_indices = list(range(numIterations))
+         else:
+             save_indices = list(range(0, numIterations, max(1, numIterations // max_saves)))
+             if save_indices[-1] != numIterations - 1:
+                 save_indices.append(numIterations - 1)
+
+         description = f"AOT-BioMaps -- ML-EM (CSR-sparse SMatrix) ---- {tumor_str} TUMOR ---- GPU {torch.cuda.current_device()}"
+         iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+         for it in iterator:
+             # projection: q = A * theta
+             projection_kernel(q_flat_gpu, SMatrix.values_gpu, SMatrix.row_ptr_gpu, SMatrix.col_ind_gpu,
+                               theta_flat_gpu, np.int32(TN),
+                               block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1),
+                               stream=stream)
+
+             # ratio: e = y / max(q, threshold)
+             ratio_kernel(e_flat_gpu, y_gpu, q_flat_gpu, np.float32(denominator_threshold), np.int32(TN),
+                          block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             # backprojection: c = A^T * e
+             drv.memset_d32_async(c_flat_gpu, 0, ZX, stream)
+             backprojection_kernel(c_flat_gpu, SMatrix.values_gpu, SMatrix.row_ptr_gpu, SMatrix.col_ind_gpu,
+                                   e_flat_gpu, np.int32(TN),
+                                   block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             # update: theta *= norm_factor_inv * c
+             update_kernel(theta_flat_gpu, c_flat_gpu, norm_factor_inv_gpu, np.int32(ZX),
+                           block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             if show_logs and (it % 10 == 0 or it == numIterations - 1):
+                 drv.Context.synchronize()
+
+             if isSavingEachIteration and it in save_indices:
+                 theta_host = np.empty(ZX, dtype=dtype)
+                 drv.memcpy_dtoh(theta_host, theta_flat_gpu)
+                 saved_theta.append(theta_host.reshape(Z, X))
+                 saved_indices.append(it)
+
+         drv.Context.synchronize()
+
+         final_result = np.empty(ZX, dtype=dtype)
+         drv.memcpy_dtoh(final_result, theta_flat_gpu)
+         final_result = final_result.reshape(Z, X)
+
+         # free local allocations
+         y_gpu.free(); q_flat_gpu.free(); e_flat_gpu.free(); c_flat_gpu.free(); theta_flat_gpu.free()
+
+         return (saved_theta, saved_indices) if isSavingEachIteration else (final_result, None)
+
+     except Exception as e:
+         print(f"Error in MLEM_sparseCSR_pycuda: {type(e).__name__}: {e}")
+         gc.collect()
+         return None, None
+
+     finally:
+         # --- CONTEXT FIX: Pop the context ---
+         if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
+             SMatrix.ctx.pop()
+         # ------------------------------------
+
+ def MLEM_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
+     """
+     MLEM using SELL-C-σ kernels already present on device.
+     y must be float32 length TN.
+     """
+     final_result = None
+
+     try:
+         # check if SMatrix is SparseSMatrix_SELL object
+         if not isinstance(SMatrix, SparseSMatrix_SELL):
+             raise TypeError("SMatrix must be a SparseSMatrix_SELL object")
+         if SMatrix.sell_values_gpu is None:
+             raise RuntimeError("SELL not built. Call allocate_sell_c_sigma_direct() first.")
+
+         # --- CONTEXT FIX: Push the context associated with SMatrix ---
+         # This ensures all subsequent PyCUDA operations use the correct GPU/context.
+         if SMatrix.ctx:
+             SMatrix.ctx.push()
+         # -----------------------------------------------------------
+
+         TN = int(SMatrix.N * SMatrix.T)
+         ZX = int(SMatrix.Z * SMatrix.X)
+         dtype = np.float32
+         block_size = 256
+
+         proj = SMatrix.sparse_mod.get_function("projection_kernel__SELL")
+         backproj = SMatrix.sparse_mod.get_function("backprojection_kernel__SELL")
+         ratio = SMatrix.sparse_mod.get_function("ratio_kernel")
+         update = SMatrix.sparse_mod.get_function("update_theta_kernel")
+
+         stream = drv.Stream()
+
+         # device buffers
+         y = y.T.flatten().astype(np.float32)
+         y_gpu = drv.mem_alloc(y.nbytes)
+         drv.memcpy_htod_async(y_gpu, y.astype(dtype), stream)
+
+         theta_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
+         drv.memcpy_htod_async(theta_gpu, np.full(ZX, 0.1, dtype=dtype), stream)
+
+         q_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
+         e_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
+         c_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
+
+         slice_ptr_gpu = SMatrix.slice_ptr_gpu
+         slice_len_gpu = SMatrix.slice_len_gpu
+         slice_height = np.int32(SMatrix.slice_height)
+
+         grid_rows = ((TN + block_size - 1) // block_size, 1, 1)
+         grid_cols = ((ZX + block_size - 1) // block_size, 1, 1)
+
+         saved_theta, saved_indices = [], []
+         if numIterations <= max_saves:
+             save_indices = list(range(numIterations))
+         else:
+             save_indices = list(range(0, numIterations, max(1, numIterations // max_saves)))
+             if save_indices[-1] != numIterations - 1:
+                 save_indices.append(numIterations - 1)
+
+         description = f"AOT-BioMaps -- ML-EM (SELL-c-σ-sparse SMatrix) ---- {tumor_str} TUMOR ---- GPU {torch.cuda.current_device()}"
+         iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+         for it in iterator:
+             # projection
+             proj(q_gpu, SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, slice_ptr_gpu, slice_len_gpu,
+                  theta_gpu, np.int32(TN), slice_height,
+                  block=(block_size, 1, 1), grid=grid_rows, stream=stream)
+
+             # ratio
+             ratio(e_gpu, y_gpu, q_gpu, np.float32(denominator_threshold), np.int32(TN),
+                   block=(block_size, 1, 1), grid=grid_rows, stream=stream)
+
+             # zero c
+             drv.memset_d32_async(c_gpu, 0, ZX, stream)
+
+             # backprojection accumulate
+             backproj(SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, slice_ptr_gpu, slice_len_gpu,
+                      e_gpu, c_gpu, np.int32(TN), slice_height,
+                      block=(block_size, 1, 1), grid=grid_rows, stream=stream)
+
+             # update
+             update(theta_gpu, c_gpu, SMatrix.norm_factor_inv_gpu, np.int32(ZX),
+                    block=(block_size, 1, 1), grid=grid_cols, stream=stream)
+
+             stream.synchronize()
+             if isSavingEachIteration and it in save_indices:
+                 out = np.empty(ZX, dtype=np.float32)
+                 drv.memcpy_dtoh(out, theta_gpu)
+                 saved_theta.append(out.reshape((SMatrix.Z, SMatrix.X)))
+                 saved_indices.append(it)
+
+         # final copy
+         res = np.empty(ZX, dtype=np.float32)
+         drv.memcpy_dtoh(res, theta_gpu)
+
+         # free temporaries
+         y_gpu.free(); q_gpu.free(); e_gpu.free(); c_gpu.free(); theta_gpu.free()
+
+         final_result = res.reshape((SMatrix.Z, SMatrix.X))
+         return (saved_theta, saved_indices) if isSavingEachIteration else (final_result, None)
+
+     except Exception as e:
+         print(f"Error in MLEM_sparseSELL_pycuda: {type(e).__name__}: {e}")
+         gc.collect()
+         return None, None
+
+     finally:
+         # --- CONTEXT FIX: Pop the context ---
+         if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
+             SMatrix.ctx.pop()
+         # ------------------------------------