AOT-biomaps 2.9.261-py3-none-any.whl → 2.9.294-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of AOT-biomaps has been flagged as potentially problematic.

@@ -1,34 +1,57 @@
  from AOT_biomaps.Config import config
+ from AOT_biomaps.AOT_Recon.ReconTools import calculate_memory_requirement, check_gpu_memory
+ from AOT_biomaps.AOT_Recon.ReconEnums import SMatrixType
+
  import torch
  import numpy as np
  from tqdm import trange
- from AOT_biomaps.AOT_Recon.ReconTools import calculate_memory_requirement, check_gpu_memory
+ import pycuda.driver as drv
+ import torch.cuda
+ import gc
+
+

  def LS(
      SMatrix,
      y,
-     numIterations=5000,
-     alpha=1e-3,
+     numIterations=100,
      isSavingEachIteration=True,
      withTumor=True,
+     alpha=1e-1,
      device=None,
+     use_numba=False,
+     denominator_threshold=1e-6,
      max_saves=5000,
-     show_logs=True
+     show_logs=True,
+     smatrixType=SMatrixType.SELL
  ):
      """
      Least Squares reconstruction using Projected Gradient Descent (PGD) with non-negativity constraint.
      Currently only implements the stable GPU version.
      """
      tumor_str = "WITH" if withTumor else "WITHOUT"
-     # Force GPU usage for now
+     # Auto-select device and method
      if device is None:
          if torch.cuda.is_available() and check_gpu_memory(config.select_best_gpu(), calculate_memory_requirement(SMatrix, y), show_logs=show_logs):
-             raise RuntimeError("CUDA is required for this implementation.")
-             device = torch.device(f"cuda:{config.select_best_gpu()}")
+             device = torch.device(f"cuda:{config.select_best_gpu()}")
+             use_gpu = True
+         else:
+             device = torch.device("cpu")
+             use_gpu = False
+     else:
+         use_gpu = device.type == "cuda"
+     # Dispatch to the appropriate implementation
+     if use_gpu:
+         if smatrixType == SMatrixType.CSR:
+             return _LS_CG_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs)
+         elif smatrixType == SMatrixType.SELL:
+             return _LS_CG_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs)
+         elif smatrixType == SMatrixType.DENSE:
+             return _LS_GPU_stable(SMatrix, y, numIterations, alpha, isSavingEachIteration, tumor_str, max_saves, show_logs=show_logs)
+         else:
+             raise ValueError("Unsupported SMatrixType for GPU LS.")
      else:
-         if device.type != "cuda":
-             raise RuntimeError("Only GPU implementation is available for now.")
-     return _LS_GPU_stable(SMatrix, y, numIterations, alpha, isSavingEachIteration, tumor_str, max_saves, show_logs=show_logs)
+         raise NotImplementedError("Only GPU implementations are currently available for LS.")

  def _LS_GPU_stable(SMatrix, y, numIterations, alpha, isSavingEachIteration, tumor_str, max_saves=5000, show_logs=True):
      """
@@ -104,3 +127,370 @@ def _LS_CPU_opti(*args, **kwargs):

  def _LS_CPU_basic(*args, **kwargs):
      raise NotImplementedError("Only _LS_GPU_stable is implemented for now.")
+
+ def _LS_CG_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
+     """
+     Least Squares (LS) reconstruction via Conjugate Gradient (CG) on the CSR format.
+     Takes the same arguments as the MLEM function, with no Python sub-functions.
+
+     SMatrix: SparseSMatrix_CSR instance (already allocated)
+     y: measured data (1D np.float32 of length TN)
+     """
+     final_result = None
+
+     # Parameters unused by CG but kept for signature compatibility: denominator_threshold, device
+
+     # --- Dot-product logic (inlined) ---
+     def _dot_product_gpu(mod, a_ptr, b_ptr, N_int, stream):
+         block_size = 256
+         grid_size = (N_int + block_size - 1) // block_size
+
+         reduction_host = np.empty(grid_size, dtype=np.float32)
+         reduction_buffer = drv.mem_alloc(reduction_host.nbytes)
+
+         dot_kernel = mod.get_function("dot_product_reduction_kernel")
+
+         dot_kernel(reduction_buffer, a_ptr, b_ptr, np.int32(N_int),
+                    block=(block_size, 1, 1), grid=(grid_size, 1, 1), stream=stream)
+
+         drv.memcpy_dtoh(reduction_host, reduction_buffer)
+         total_dot = np.sum(reduction_host)
+
+         reduction_buffer.free()
+         return total_dot
+     # -----------------------------------------------
+
+     try:
+         if not isinstance(SMatrix, SparseSMatrix_CSR):
+             raise TypeError("SMatrix must be a SparseSMatrix_CSR object")
+
+         if SMatrix.ctx:
+             SMatrix.ctx.push()
+
+         dtype = np.float32
+         TN = SMatrix.N * SMatrix.T
+         ZX = SMatrix.Z * SMatrix.X
+         Z = SMatrix.Z
+         X = SMatrix.X
+         block_size = 256
+         tolerance = 1e-12
+
+         if show_logs:
+             print(f"Executing on GPU device index: {SMatrix.device.primary_context.device.name()}")
+             print(f"Dim X: {X}, Dim Z: {Z}, TN: {TN}, ZX: {ZX}")
+
+         stream = drv.Stream()
+
+         # Fetch the kernels
+         projection_kernel = SMatrix.sparse_mod.get_function('projection_kernel__CSR')
+         backprojection_kernel = SMatrix.sparse_mod.get_function('backprojection_kernel__CSR')
+         axpby_kernel = SMatrix.sparse_mod.get_function("vector_axpby_kernel")
+         minus_axpy_kernel = SMatrix.sparse_mod.get_function("vector_minus_axpy_kernel")
+
+         # --- Buffer allocation (raw pointers) ---
+         y = y.T.flatten().astype(dtype)
+         y_gpu = drv.mem_alloc(y.nbytes)
+         drv.memcpy_htod_async(y_gpu, y.astype(dtype), stream)
+
+         theta_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)  # lambda
+         drv.memcpy_htod_async(theta_flat_gpu, np.full(ZX, 0.1, dtype=dtype), stream)
+
+         q_flat_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)    # q = A*p
+         r_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)    # r (residual)
+         p_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)    # p (direction)
+         z_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)    # z = A^T A p
+         ATy_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)  # A^T y (constant)
+
+         # --- CG initialization ---
+
+         # 1. ATy = A^T * y
+         drv.memset_d32_async(ATy_flat_gpu, 0, ZX, stream)
+         backprojection_kernel(ATy_flat_gpu, SMatrix.values_gpu, SMatrix.row_ptr_gpu, SMatrix.col_ind_gpu,
+                               y_gpu, np.int32(TN),
+                               block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+
+         # 2. q = A * theta_0
+         projection_kernel(q_flat_gpu, SMatrix.values_gpu, SMatrix.row_ptr_gpu, SMatrix.col_ind_gpu,
+                           theta_flat_gpu, np.int32(TN),
+                           block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+
+         # 3. r_temp = A^T * q = A^T A theta_0
+         drv.memset_d32_async(r_flat_gpu, 0, ZX, stream)
+         backprojection_kernel(r_flat_gpu, SMatrix.values_gpu, SMatrix.row_ptr_gpu, SMatrix.col_ind_gpu,
+                               q_flat_gpu, np.int32(TN),
+                               block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+
+         # 4. r_0 = ATy - r_temp (r = ATy + (-1)*r_temp)
+         axpby_kernel(r_flat_gpu, ATy_flat_gpu, r_flat_gpu,
+                      np.float32(1.0), np.float32(-1.0), np.int32(ZX),
+                      block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
+
+         # 5. p_0 = r_0
+         drv.memcpy_dtod(p_flat_gpu, r_flat_gpu, ZX * np.dtype(dtype).itemsize)
+
+         # 6. rho_prev = ||r_0||^2
+         rho_prev = _dot_product_gpu(SMatrix.sparse_mod, r_flat_gpu, r_flat_gpu, ZX, stream)
+
+         # --- Iteration loop ---
+         saved_theta, saved_indices = [], []
+         if numIterations <= max_saves:
+             save_indices = list(range(numIterations))
+         else:
+             save_indices = list(range(0, numIterations, max(1, numIterations // max_saves)))
+             if save_indices[-1] != numIterations - 1:
+                 save_indices.append(numIterations - 1)
+
+         description = f"AOT-BioMaps -- LS-CG (CSR-sparse SMatrix) ---- {tumor_str} TUMOR ---- GPU {torch.cuda.current_device()}"
+         iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+
+         for it in iterator:
+             # a. q = A * p
+             projection_kernel(q_flat_gpu, SMatrix.values_gpu, SMatrix.row_ptr_gpu, SMatrix.col_ind_gpu,
+                               p_flat_gpu, np.int32(TN),
+                               block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             # b. z = A^T * q = A^T A p
+             drv.memset_d32_async(z_flat_gpu, 0, ZX, stream)
+             backprojection_kernel(z_flat_gpu, SMatrix.values_gpu, SMatrix.row_ptr_gpu, SMatrix.col_ind_gpu,
+                                   q_flat_gpu, np.int32(TN),
+                                   block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             # c. alpha = rho_prev / <p, z>
+             pAp = _dot_product_gpu(SMatrix.sparse_mod, p_flat_gpu, z_flat_gpu, ZX, stream)
+
+             if abs(pAp) < 1e-15: break
+             alpha = rho_prev / pAp
+
+             # d. theta = theta + alpha * p
+             axpby_kernel(theta_flat_gpu, theta_flat_gpu, p_flat_gpu,
+                          np.float32(1.0), alpha, np.int32(ZX),
+                          block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             # e. r = r - alpha * z
+             minus_axpy_kernel(r_flat_gpu, z_flat_gpu, alpha, np.int32(ZX),
+                               block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             # f. rho_curr = ||r||^2
+             rho_curr = _dot_product_gpu(SMatrix.sparse_mod, r_flat_gpu, r_flat_gpu, ZX, stream)
+
+             if rho_curr < tolerance: break
+
+             # g. beta = rho_curr / rho_prev
+             beta = rho_curr / rho_prev
+
+             # h. p = r + beta * p
+             axpby_kernel(p_flat_gpu, r_flat_gpu, p_flat_gpu,
+                          np.float32(1.0), beta, np.int32(ZX),
+                          block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             rho_prev = rho_curr
+
+             if show_logs and (it % 10 == 0 or it == numIterations - 1):
+                 drv.Context.synchronize()
+
+             if isSavingEachIteration and it in save_indices:
+                 theta_host = np.empty(ZX, dtype=dtype)
+                 drv.memcpy_dtoh(theta_host, theta_flat_gpu)
+                 saved_theta.append(theta_host.reshape(Z, X))
+                 saved_indices.append(it)
+
+         drv.Context.synchronize()
+
+         final_result = np.empty(ZX, dtype=dtype)
+         drv.memcpy_dtoh(final_result, theta_flat_gpu)
+         final_result = final_result.reshape(Z, X)
+
+         # Free buffers
+         y_gpu.free(); q_flat_gpu.free(); r_flat_gpu.free(); p_flat_gpu.free(); z_flat_gpu.free(); theta_flat_gpu.free(); ATy_flat_gpu.free()
+
+         return (saved_theta, saved_indices) if isSavingEachIteration else (final_result, None)
+
+     except Exception as e:
+         print(f"Error in LS_CG_sparseCSR_pycuda: {type(e).__name__}: {e}")
+         gc.collect()
+         return None, None
+
+     finally:
+         if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
+             SMatrix.ctx.pop()
+
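
Both sparse LS variants run the same conjugate-gradient iteration on the normal equations A^T A theta = A^T y; the CSR and SELL code paths differ only in how the projection (A) and backprojection (A^T) kernels are invoked. As a reference for initialization steps 1-6 and loop steps a-h above, here is a dense NumPy sketch of that iteration (illustration only; the package never forms A densely):

```python
import numpy as np

def cg_normal_equations(A, y, num_iterations=100, tolerance=1e-12):
    """Reference CG on A^T A theta = A^T y, mirroring the GPU loop above."""
    theta = np.full(A.shape[1], 0.1, dtype=np.float32)  # same 0.1 initial image
    r = A.T @ y - A.T @ (A @ theta)      # r_0 = A^T y - A^T A theta_0
    p = r.copy()                         # p_0 = r_0
    rho_prev = float(r @ r)              # ||r_0||^2
    for _ in range(num_iterations):
        z = A.T @ (A @ p)                # z = A^T A p
        pAp = float(p @ z)
        if abs(pAp) < 1e-15:             # breakdown guard
            break
        alpha = rho_prev / pAp
        theta += alpha * p               # theta = theta + alpha * p
        r -= alpha * z                   # r = r - alpha * z
        rho_curr = float(r @ r)
        if rho_curr < tolerance:         # converged
            break
        p = r + (rho_curr / rho_prev) * p  # p = r + beta * p
        rho_prev = rho_curr
    return theta
```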
+ def _LS_CG_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
+     """
+     Least Squares (LS) reconstruction via Conjugate Gradient (CG) on the SELL-C-sigma format.
+     Takes the same arguments as the MLEM function, with no Python sub-functions.
+
+     SMatrix: SparseSMatrix_SELL instance (already allocated)
+     y: measured data (1D np.float32 of length TN)
+     """
+     final_result = None
+
+     # --- Dot-product logic (inlined) ---
+     def _dot_product_gpu(mod, a_ptr, b_ptr, N_int, stream):
+         block_size = 256
+         grid_size = (N_int + block_size - 1) // block_size
+
+         reduction_host = np.empty(grid_size, dtype=np.float32)
+         reduction_buffer = drv.mem_alloc(reduction_host.nbytes)
+
+         dot_kernel = mod.get_function("dot_product_reduction_kernel")
+
+         dot_kernel(reduction_buffer, a_ptr, b_ptr, np.int32(N_int),
+                    block=(block_size, 1, 1), grid=(grid_size, 1, 1), stream=stream)
+
+         drv.memcpy_dtoh(reduction_host, reduction_buffer)
+         total_dot = np.sum(reduction_host)
+
+         reduction_buffer.free()
+         return total_dot
+     # -----------------------------------------------
+
+     try:
+         if not isinstance(SMatrix, SparseSMatrix_SELL):
+             raise TypeError("SMatrix must be a SparseSMatrix_SELL object")
+         if SMatrix.sell_values_gpu is None:
+             raise RuntimeError("SELL not built. Call allocate_sell_c_sigma_direct() first.")
+
+         if SMatrix.ctx:
+             SMatrix.ctx.push()
+
+         dtype = np.float32
+         TN = int(SMatrix.N * SMatrix.T)
+         ZX = int(SMatrix.Z * SMatrix.X)
+         Z = SMatrix.Z
+         X = SMatrix.X
+         block_size = 256
+         tolerance = 1e-12
+
+         # SELL kernels and parameters
+         projection_kernel = SMatrix.sparse_mod.get_function("projection_kernel__SELL")
+         backprojection_kernel = SMatrix.sparse_mod.get_function("backprojection_kernel__SELL")
+         axpby_kernel = SMatrix.sparse_mod.get_function("vector_axpby_kernel")
+         minus_axpy_kernel = SMatrix.sparse_mod.get_function("vector_minus_axpy_kernel")
+         slice_height = np.int32(SMatrix.slice_height)
+         grid_rows = ((TN + block_size - 1) // block_size, 1, 1)
+
+         stream = drv.Stream()
+
+         # Buffer allocation
+         y = y.T.flatten().astype(dtype)
+         y_gpu = drv.mem_alloc(y.nbytes)
+         drv.memcpy_htod_async(y_gpu, y.astype(dtype), stream)
+
+         theta_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
+         drv.memcpy_htod_async(theta_flat_gpu, np.full(ZX, 0.1, dtype=dtype), stream)
+
+         q_flat_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
+         r_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
+         p_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
+         z_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
+         ATy_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)
+
+         # --- CG initialization ---
+
+         # 1. ATy = A^T * y
+         drv.memset_d32_async(ATy_flat_gpu, 0, ZX, stream)
+         backprojection_kernel(SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, SMatrix.slice_ptr_gpu, SMatrix.slice_len_gpu,
+                               y_gpu, ATy_flat_gpu, np.int32(TN), slice_height,
+                               block=(block_size, 1, 1), grid=grid_rows, stream=stream)
+
+         # 2. q = A * theta_0
+         projection_kernel(q_flat_gpu, SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, SMatrix.slice_ptr_gpu, SMatrix.slice_len_gpu,
+                           theta_flat_gpu, np.int32(TN), slice_height,
+                           block=(block_size, 1, 1), grid=grid_rows, stream=stream)
+
+         # 3. r_temp = A^T * q = A^T A theta_0
+         drv.memset_d32_async(r_flat_gpu, 0, ZX, stream)
+         backprojection_kernel(SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, SMatrix.slice_ptr_gpu, SMatrix.slice_len_gpu,
+                               q_flat_gpu, r_flat_gpu, np.int32(TN), slice_height,
+                               block=(block_size, 1, 1), grid=grid_rows, stream=stream)
+
+         # 4. r_0 = ATy - r_temp
+         axpby_kernel(r_flat_gpu, ATy_flat_gpu, r_flat_gpu,
+                      np.float32(1.0), np.float32(-1.0), np.int32(ZX),
+                      block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
+
+         # 5. p_0 = r_0
+         drv.memcpy_dtod(p_flat_gpu, r_flat_gpu, ZX * np.dtype(dtype).itemsize)
+
+         # 6. rho_prev = ||r_0||^2
+         rho_prev = _dot_product_gpu(SMatrix.sparse_mod, r_flat_gpu, r_flat_gpu, ZX, stream)
+
+         # --- Iteration loop ---
+         saved_theta, saved_indices = [], []
+         if numIterations <= max_saves:
+             save_indices = list(range(numIterations))
+         else:
+             save_indices = list(range(0, numIterations, max(1, numIterations // max_saves)))
+             if save_indices[-1] != numIterations - 1:
+                 save_indices.append(numIterations - 1)
+
+         description = f"AOT-BioMaps -- LS-CG (SELL-c-σ-sparse SMatrix) ---- {tumor_str} TUMOR ---- GPU {torch.cuda.current_device()}"
+         iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+
+         for it in iterator:
+             # a. q = A * p
+             projection_kernel(q_flat_gpu, SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, SMatrix.slice_ptr_gpu, SMatrix.slice_len_gpu,
+                               p_flat_gpu, np.int32(TN), slice_height,
+                               block=(block_size, 1, 1), grid=grid_rows, stream=stream)
+
+             # b. z = A^T * q = A^T A p
+             drv.memset_d32_async(z_flat_gpu, 0, ZX, stream)
+             backprojection_kernel(SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, SMatrix.slice_ptr_gpu, SMatrix.slice_len_gpu,
+                                   q_flat_gpu, z_flat_gpu, np.int32(TN), slice_height,
+                                   block=(block_size, 1, 1), grid=grid_rows, stream=stream)
+
+             # c. alpha = rho_prev / <p, z>
+             pAp = _dot_product_gpu(SMatrix.sparse_mod, p_flat_gpu, z_flat_gpu, ZX, stream)
+
+             if abs(pAp) < 1e-15: break
+             alpha = rho_prev / pAp
+
+             # d. theta = theta + alpha * p
+             axpby_kernel(theta_flat_gpu, theta_flat_gpu, p_flat_gpu,
+                          np.float32(1.0), alpha, np.int32(ZX),
+                          block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             # e. r = r - alpha * z
+             minus_axpy_kernel(r_flat_gpu, z_flat_gpu, alpha, np.int32(ZX),
+                               block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             # f. rho_curr = ||r||^2
+             rho_curr = _dot_product_gpu(SMatrix.sparse_mod, r_flat_gpu, r_flat_gpu, ZX, stream)
+
+             if rho_curr < tolerance: break
+
+             # g. beta = rho_curr / rho_prev
+             beta = rho_curr / rho_prev
+
+             # h. p = r + beta * p
+             axpby_kernel(p_flat_gpu, r_flat_gpu, p_flat_gpu,
+                          np.float32(1.0), beta, np.int32(ZX),
+                          block=(block_size, 1, 1), grid=((ZX + block_size - 1) // block_size, 1, 1), stream=stream)
+
+             rho_prev = rho_curr
+
+             stream.synchronize()
+             if isSavingEachIteration and it in save_indices:
+                 out = np.empty(ZX, dtype=dtype)
+                 drv.memcpy_dtoh(out, theta_flat_gpu)
+                 saved_theta.append(out.reshape((Z, X)))
+                 saved_indices.append(it)
+
+         # final copy
+         res = np.empty(ZX, dtype=np.float32)
+         drv.memcpy_dtoh(res, theta_flat_gpu)
+         final_result = res.reshape((Z, X))
+
+         # free temporaries
+         y_gpu.free(); q_flat_gpu.free(); r_flat_gpu.free(); p_flat_gpu.free(); z_flat_gpu.free(); theta_flat_gpu.free(); ATy_flat_gpu.free()
+
+         return (saved_theta, saved_indices) if isSavingEachIteration else (final_result, None)
+
+     except Exception as e:
+         print(f"Error in LS_CG_sparseSELL_pycuda: {type(e).__name__}: {e}")
+         gc.collect()
+         return None, None
+
+     finally:
+         if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
+             SMatrix.ctx.pop()
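
Both CG loops (and the MLEM variants below) share the same snapshot bookkeeping: at most roughly `max_saves` intermediate images are kept, by striding over iterations and always retaining the last one. Extracted as a standalone sketch for clarity:

```python
def plan_save_indices(num_iterations, max_saves):
    # Mirrors the save_indices logic above: keep every iterate when it fits,
    # otherwise subsample with a fixed stride and force-include the final one.
    if num_iterations <= max_saves:
        return list(range(num_iterations))
    indices = list(range(0, num_iterations, max(1, num_iterations // max_saves)))
    if indices[-1] != num_iterations - 1:
        indices.append(num_iterations - 1)
    return indices
```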
@@ -26,7 +26,6 @@ def MLEM(
      max_saves=5000,
      show_logs=True,
      smatrixType=SMatrixType.SELL,
-     Z=350,
  ):
      """
      Unified MLEM algorithm for Acousto-Optic Tomography.
@@ -59,7 +58,7 @@ def MLEM(
      # Dispatch to the appropriate implementation
      if use_gpu:
          if smatrixType == SMatrixType.CSR:
-             return MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, Z, show_logs)
+             return MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs)
          elif smatrixType == SMatrixType.SELL:
              return MLEM_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs)
          elif smatrixType == SMatrixType.DENSE:
@@ -229,7 +228,7 @@ def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str,
          print(f"Error in optimized CPU MLEM: {type(e).__name__}: {e}")
          return None, None

- def MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs=True):
+ def MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
      """
      SMatrix: instance of SparseMatrixGPU (already allocated)
      y: measured data (1D np.float32 of length TN)
@@ -237,25 +236,39 @@ def MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumo
      Assumptions:
      - SMatrix.values_gpu and SMatrix.col_ind_gpu and SMatrix.row_ptr_gpu are device pointers
      - SMatrix.norm_factor_inv_gpu exists
+     - SMatrix.ctx is the PyCUDA context for the target GPU.
      """
+
+     # We use a final_result placeholder to ensure it's defined outside the try block
+     final_result = None
+
      try:
          if not isinstance(SMatrix, SparseSMatrix_CSR):
              raise TypeError("SMatrix must be a SparseSMatrix_CSR object")
+
+         # --- CONTEXT FIX: Push the context associated with SMatrix ---
+         # This ensures all subsequent PyCUDA operations use the correct GPU/context.
+         if SMatrix.ctx:
+             SMatrix.ctx.push()
+         # -----------------------------------------------------------
+
          dtype = np.float32
         TN = SMatrix.N * SMatrix.T
          ZX = SMatrix.Z * SMatrix.X
-         if Z is None:
-             Z = SMatrix.Z
+         # Ensure Z and X are correctly defined for reshaping
+         Z = SMatrix.Z
          X = SMatrix.X

          if show_logs:
+             # We assume SMatrix was initialized using the correct device index.
+             print(f"Executing on GPU device index: {SMatrix.device.primary_context.device.name()}")
              print(f"Dim X: {X}, Dim Z: {Z}, TN: {TN}, ZX: {ZX}")

-         # Use existing context from SMatrix
          # streams
          stream = drv.Stream()

          # allocate device buffers
+         y = y.T.flatten().astype(np.float32)
          y_gpu = drv.mem_alloc(y.nbytes)
          drv.memcpy_htod_async(y_gpu, y.astype(dtype), stream)

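The "CONTEXT FIX" blocks introduced in this hunk pair every `ctx.push()` with a `ctx.pop()` in a `finally` clause (added in the hunks further down), so an exception cannot leave the PyCUDA context stack unbalanced. Reduced to a hypothetical helper for illustration, assuming `SMatrix.ctx` is a PyCUDA context object as the docstring states:

```python
def run_in_smatrix_context(SMatrix, gpu_work):
    # gpu_work: a zero-argument callable performing the PyCUDA operations.
    try:
        if SMatrix.ctx:
            SMatrix.ctx.push()   # make the context owning SMatrix's buffers current
        return gpu_work()
    finally:
        if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
            SMatrix.ctx.pop()    # rebalance the context stack, even on error
```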
@@ -269,12 +282,11 @@ def MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumo
          e_flat_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize)
          c_flat_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize)

-         mlem_mod = drv.module_from_file('AOT_biomaps_kernels.cubin')
-         projection_kernel = mlem_mod.get_function('projection_kernel__CSR')
-         backprojection_kernel = mlem_mod.get_function('backprojection_kernel__CSR')
-         ratio_kernel = mlem_mod.get_function('ratio_kernel')
-         update_kernel = mlem_mod.get_function('update_theta_kernel')
-
+         # Assuming the cubin file is found globally or managed by the caller
+         projection_kernel = SMatrix.sparse_mod.get_function('projection_kernel__CSR')
+         backprojection_kernel = SMatrix.sparse_mod.get_function('backprojection_kernel__CSR')
+         ratio_kernel = SMatrix.sparse_mod.get_function('ratio_kernel')
+         update_kernel = SMatrix.sparse_mod.get_function('update_theta_kernel')
          block_size = 256

          saved_theta, saved_indices = [], []
@@ -296,7 +308,7 @@ def MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumo

              # ratio: e = y / max(q, threshold)
              ratio_kernel(e_flat_gpu, y_gpu, q_flat_gpu, np.float32(denominator_threshold), np.int32(TN),
-                 block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+                          block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)

              # backprojection: c = A^T * e
              drv.memset_d32_async(c_flat_gpu, 0, ZX, stream)
@@ -319,45 +331,60 @@ def MLEM_sparseCSR_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumo

          drv.Context.synchronize()

-         result = np.empty(ZX, dtype=dtype)
-         drv.memcpy_dtoh(result, theta_flat_gpu)
-         result = result.reshape(Z, X)
+         final_result = np.empty(ZX, dtype=dtype)
+         drv.memcpy_dtoh(final_result, theta_flat_gpu)
+         final_result = final_result.reshape(Z, X)

          # free local allocations
          y_gpu.free(); q_flat_gpu.free(); e_flat_gpu.free(); c_flat_gpu.free(); theta_flat_gpu.free()

-         return (saved_theta, saved_indices) if isSavingEachIteration else (result, None)
+         return (saved_theta, saved_indices) if isSavingEachIteration else (final_result, None)

      except Exception as e:
          print(f"Error in MLEM_sparseCSR_pycuda: {type(e).__name__}: {e}")
          gc.collect()
          return None, None
+
+     finally:
+         # --- CONTEXT FIX: Pop the context ---
+         if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
+             SMatrix.ctx.pop()
+         # ------------------------------------

- def MLEM_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs=True):
+ def MLEM_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
      """
      MLEM using SELL-C-σ kernels already present on device.
      y must be float32 length TN.
      """
+     final_result = None
+
      try:
          # check if SMatrix is SparseSMatrix_SELL object
          if not isinstance(SMatrix, SparseSMatrix_SELL):
              raise TypeError("SMatrix must be a SparseSMatrix_SELL object")
          if SMatrix.sell_values_gpu is None:
              raise RuntimeError("SELL not built. Call allocate_sell_c_sigma_direct() first.")
+
+         # --- CONTEXT FIX: Push the context associated with SMatrix ---
+         # This ensures all subsequent PyCUDA operations use the correct GPU/context.
+         if SMatrix.ctx:
+             SMatrix.ctx.push()
+         # -----------------------------------------------------------
+
          TN = int(SMatrix.N * SMatrix.T)
          ZX = int(SMatrix.Z * SMatrix.X)
          dtype = np.float32
          block_size = 256

-         mod = SMatrix.sparse_mod
-         proj = mod.get_function("projection_kernel__SELL")
-         backproj = mod.get_function("backprojection_kernel__SELL")
-         ratio = mod.get_function("ratio_kernel")
-         update = mod.get_function("update_theta_kernel")
+         proj = SMatrix.sparse_mod.get_function("projection_kernel__SELL")
+         backproj = SMatrix.sparse_mod.get_function("backprojection_kernel__SELL")
+         ratio = SMatrix.sparse_mod.get_function("ratio_kernel")
+         update = SMatrix.sparse_mod.get_function("update_theta_kernel")

          stream = drv.Stream()

          # device buffers
+         y = y.T.flatten().astype(np.float32)
          y_gpu = drv.mem_alloc(y.nbytes)
          drv.memcpy_htod_async(y_gpu, y.astype(dtype), stream)

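Read together, the kernel sequence in these MLEM loops (projection, ratio with a clamped denominator, backprojection, multiplicative update against the precomputed `norm_factor_inv_gpu`) is the classic MLEM update theta <- theta * (A^T(y / max(A theta, eps))) / (A^T 1). A dense NumPy sketch under that reading, for reference only (the released code keeps A in CSR/SELL form on the GPU):

```python
import numpy as np

def mlem_reference(A, y, num_iterations=100, denominator_threshold=1e-6):
    theta = np.full(A.shape[1], 0.1, dtype=np.float32)
    norm_inv = 1.0 / np.maximum(A.sum(axis=0), denominator_threshold)  # 1 / A^T 1
    for _ in range(num_iterations):
        q = A @ theta                                  # projection: q = A theta
        e = y / np.maximum(q, denominator_threshold)   # ratio with clamped denominator
        c = A.T @ e                                    # backprojection: c = A^T e
        theta *= c * norm_inv                          # multiplicative update
    return theta
```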
@@ -420,9 +447,17 @@ def MLEM_sparseSELL_pycuda(SMatrix, y, numIterations, isSavingEachIteration, tum

          # free temporaries
          y_gpu.free(); q_gpu.free(); e_gpu.free(); c_gpu.free(); theta_gpu.free()
-         return (saved_theta, saved_indices) if isSavingEachIteration else (res.reshape((SMatrix.Z, SMatrix.X)), None)
+
+         final_result = res.reshape((SMatrix.Z, SMatrix.X))
+         return (saved_theta, saved_indices) if isSavingEachIteration else (final_result, None)
+
      except Exception as e:
          print(f"Error in MLEM_sparseSELL_pycuda: {type(e).__name__}: {e}")
          gc.collect()
          return None, None
-
+
+     finally:
+         # --- CONTEXT FIX: Pop the context ---
+         if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
+             SMatrix.ctx.pop()
+         # ------------------------------------