AOT-biomaps 2.9.176__py3-none-any.whl → 2.9.300__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of AOT-biomaps might be problematic.

Files changed (29)
  1. AOT_biomaps/AOT_Acoustic/StructuredWave.py +2 -2
  2. AOT_biomaps/AOT_Acoustic/_mainAcoustic.py +11 -6
  3. AOT_biomaps/AOT_Experiment/Tomography.py +74 -4
  4. AOT_biomaps/AOT_Experiment/_mainExperiment.py +95 -55
  5. AOT_biomaps/AOT_Recon/AOT_Optimizers/DEPIERRO.py +48 -13
  6. AOT_biomaps/AOT_Recon/AOT_Optimizers/LS.py +406 -13
  7. AOT_biomaps/AOT_Recon/AOT_Optimizers/MAPEM.py +118 -38
  8. AOT_biomaps/AOT_Recon/AOT_Optimizers/MLEM.py +390 -102
  9. AOT_biomaps/AOT_Recon/AOT_Optimizers/PDHG.py +443 -12
  10. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/RelativeDifferences.py +10 -14
  11. AOT_biomaps/AOT_Recon/AOT_SparseSMatrix/SparseSMatrix_CSR.py +274 -0
  12. AOT_biomaps/AOT_Recon/AOT_SparseSMatrix/SparseSMatrix_SELL.py +331 -0
  13. AOT_biomaps/AOT_Recon/AOT_SparseSMatrix/__init__.py +2 -0
  14. AOT_biomaps/AOT_Recon/AOT_biomaps_kernels.cubin +0 -0
  15. AOT_biomaps/AOT_Recon/AlgebraicRecon.py +259 -153
  16. AOT_biomaps/AOT_Recon/AnalyticRecon.py +27 -42
  17. AOT_biomaps/AOT_Recon/BayesianRecon.py +84 -151
  18. AOT_biomaps/AOT_Recon/DeepLearningRecon.py +1 -1
  19. AOT_biomaps/AOT_Recon/PrimalDualRecon.py +162 -102
  20. AOT_biomaps/AOT_Recon/ReconEnums.py +27 -2
  21. AOT_biomaps/AOT_Recon/ReconTools.py +229 -12
  22. AOT_biomaps/AOT_Recon/__init__.py +1 -0
  23. AOT_biomaps/AOT_Recon/_mainRecon.py +72 -58
  24. AOT_biomaps/__init__.py +4 -53
  25. {aot_biomaps-2.9.176.dist-info → aot_biomaps-2.9.300.dist-info}/METADATA +2 -1
  26. aot_biomaps-2.9.300.dist-info/RECORD +47 -0
  27. aot_biomaps-2.9.176.dist-info/RECORD +0 -43
  28. {aot_biomaps-2.9.176.dist-info → aot_biomaps-2.9.300.dist-info}/WHEEL +0 -0
  29. {aot_biomaps-2.9.176.dist-info → aot_biomaps-2.9.300.dist-info}/top_level.txt +0 -0
AOT_biomaps/AOT_Recon/AOT_Optimizers/PDHG.py

@@ -1,8 +1,10 @@
- from AOT_biomaps.AOT_Recon.ReconTools import power_method, gradient, div, proj_l2, prox_G, prox_F_star
+ from AOT_biomaps.AOT_Recon.ReconTools import power_method, gradient, div, proj_l2, prox_G, prox_F_star, _call_axpby, _call_minus_axpy, compute_TV_cpu, power_method_estimate_L__SELL, calculate_memory_requirement, check_gpu_memory
  from AOT_biomaps.Config import config
- from AOT_biomaps.AOT_Recon.ReconEnums import NoiseType
+ from AOT_biomaps.AOT_Recon.ReconEnums import NoiseType, SMatrixType
  import torch
  from tqdm import trange
+ import numpy as np
+ import pycuda.driver as drv

  '''
  This module implements Primal-Dual Hybrid Gradient (PDHG) methods for solving inverse problems in Acousto-Optic Tomography.
@@ -11,6 +13,103 @@ The methods can run on both CPU and GPU, with configurations set in the AOT_biom
  '''

  def CP_TV(
+ SMatrix,
+ y,
+ alpha=None, # TV regularization parameter (if None, alpha is auto-scaled)
+ beta=1e-4, # Tikhonov regularization parameter
+ theta=1.0,
+ numIterations=5000,
+ isSavingEachIteration=True,
+ L=None,
+ withTumor=True,
+ device=None,
+ max_saves=5000,
+ show_logs=True,
+ smatrixType=SMatrixType.SELL,
+ k_security=0.8,
+ use_power_method=True,
+ auto_alpha_gamma=0.05, # gamma for auto alpha: alpha = gamma * data_term / tv_term
+ apply_positivity_clamp=True,
+ tikhonov_as_gradient=False, # if True, apply -tau*2*beta*x instead of prox multiplicative
+ use_laplacian=True, # enable Laplacian (Hessian scalar) penalty
+ laplacian_beta_scale=1.0 # multiply beta for laplacian term if you want separate scaling
+ ):
+ # try:
+ tumor_str = "WITH" if withTumor else "WITHOUT"
+ # Auto-select device and method
+ if device is None:
+ if torch.cuda.is_available() and check_gpu_memory(config.select_best_gpu(), calculate_memory_requirement(SMatrix, y), show_logs=show_logs):
+ device = torch.device(f"cuda:{config.select_best_gpu()}")
+ use_gpu = True
+ else:
+ device = torch.device("cpu")
+ use_gpu = False
+ else:
+ use_gpu = device.type == "cuda"
+ # Dispatch to the appropriate implementation
+ if use_gpu:
+ if smatrixType == SMatrixType.CSR:
+ raise NotImplementedError("GPU Chambolle Pock (LS-TV) with CSR not implemented.")
+ elif smatrixType == SMatrixType.SELL:
+ return CP_TV_Tikhonov_sparseSELL_pycuda(SMatrix, y, alpha,beta, theta, numIterations, isSavingEachIteration, L, tumor_str, device, max_saves, show_logs, k_security, use_power_method, auto_alpha_gamma, apply_positivity_clamp, tikhonov_as_gradient, use_laplacian, laplacian_beta_scale)
+ elif smatrixType == SMatrixType.DENSE:
+ return CP_TV_dense(SMatrix, y, alpha, theta, numIterations, isSavingEachIteration, L, tumor_str, device, max_saves, show_logs)
+ else:
+ raise ValueError("Unsupported SMatrixType for GPU Chambolle Pock (LS-TV).")
+ else:
+ raise NotImplementedError("CPU Chambolle Pock (LS-TV) not implemented.")
+
+ def CP_KL(
+ SMatrix,
+ y,
+ alpha=None, # TV regularization parameter (if None, alpha is auto-scaled)
+ beta=1e-4, # Tikhonov regularization parameter
+ theta=1.0,
+ numIterations=5000,
+ isSavingEachIteration=True,
+ L=None,
+ withTumor=True,
+ device=None,
+ max_saves=5000,
+ show_logs=True,
+ smatrixType=SMatrixType.SELL,
+ k_security=0.8,
+ use_power_method=True,
+ auto_alpha_gamma=0.05, # gamma for auto alpha: alpha = gamma * data_term / tv_term
+ apply_positivity_clamp=True,
+ tikhonov_as_gradient=False, # if True, apply -tau*2*beta*x instead of prox multiplicative
+ use_laplacian=True, # enable Laplacian (Hessian scalar) penalty
+ laplacian_beta_scale=1.0 # multiply beta for laplacian term if you want separate scaling
+ ):
+ # try:
+ tumor_str = "WITH" if withTumor else "WITHOUT"
+ # Auto-select device and method
+ if device is None:
+ if torch.cuda.is_available() and check_gpu_memory(config.select_best_gpu(), calculate_memory_requirement(SMatrix, y), show_logs=show_logs):
+ device = torch.device(f"cuda:{config.select_best_gpu()}")
+ use_gpu = True
+ else:
+ device = torch.device("cpu")
+ use_gpu = False
+ else:
+ use_gpu = device.type == "cuda"
+ # Dispatch to the appropriate implementation
+ if use_gpu:
+ if smatrixType == SMatrixType.CSR:
+ raise NotImplementedError("GPU Chambolle Pock (LS-KL) with CSR not implemented.")
+ elif smatrixType == SMatrixType.SELL:
+ raise NotImplementedError("GPU Chambolle Pock (LS-KL) with SELL not implemented.")
+ elif smatrixType == SMatrixType.DENSE:
+ return CP_KL(SMatrix, y, alpha, theta, numIterations, isSavingEachIteration, L, tumor_str, device, max_saves, show_logs)
+ else:
+ raise ValueError("Unsupported SMatrixType for GPU Chambolle Pock (LS-KL).")
+ else:
+ raise NotImplementedError("CPU Chambolle Pock (LS-KL) not implemented.")
+
+
+
+
+ def CP_TV_dense(
  SMatrix,
  y,
  alpha=1e-1,
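The new CP_TV entry point above is a thin dispatcher: it auto-selects a device, checks free GPU memory, and routes to a SELL, dense, or (unimplemented) CSR backend. Note that the SMatrixType.DENSE branch of the companion CP_KL dispatcher calls CP_KL itself rather than a dense backend, so as published it would recurse. A minimal usage sketch, assuming these hunks belong to AOT_Optimizers/PDHG.py; the variables smatrix and measurements are illustrative placeholders, not names from the package:

    # Hypothetical call into the new dispatcher; smatrix/measurements are placeholders.
    from AOT_biomaps.AOT_Recon.AOT_Optimizers.PDHG import CP_TV
    from AOT_biomaps.AOT_Recon.ReconEnums import SMatrixType

    recons, saved_iters = CP_TV(
        smatrix,                       # SELL-packed system matrix exposing .T/.Z/.X/.N
        measurements,                  # measurement data y
        alpha=None,                    # None -> auto-scaled from the data/TV balance
        beta=1e-4,                     # Tikhonov weight
        numIterations=2000,
        smatrixType=SMatrixType.SELL,  # CSR and the CPU path raise NotImplementedError
    )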
@@ -21,13 +120,14 @@ def CP_TV
  withTumor=True,
  device=None,
  max_saves=5000,
+ show_logs=True,
  ):
  """
  Chambolle-Pock algorithm for Total Variation (TV) regularization.
  Works on both CPU and GPU.
  Args:
  SMatrix: System matrix (shape: T, Z, X, N)
- y: Measurement data (shape: T, X, N)
+ y: Measurement data (shape: T, N)
  alpha: Regularization parameter for TV
  theta: Relaxation parameter (1.0 for standard Chambolle-Pock)
  numIterations: Number of iterations
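To make the shapes in the corrected docstring concrete: the solver treats the 4-D system matrix as a linear operator from a (Z·X)-pixel image to a (T·N)-sample measurement vector. A dimension-only sketch (sizes invented, and the transpose order is an assumption about the internal layout, not taken from the package):

    import numpy as np

    T, Z, X, N = 100, 64, 64, 32          # invented sizes
    SMatrix = np.random.rand(T, Z, X, N)  # system matrix (T, Z, X, N)
    y = np.random.rand(T, N)              # measurement data (T, N)

    # Viewed as a matrix, SMatrix maps image space to measurement space:
    A = SMatrix.transpose(0, 3, 1, 2).reshape(T * N, Z * X)
    x = np.random.rand(Z * X)
    print(A.dot(x).shape)                 # (T*N,)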
@@ -92,10 +192,10 @@ def CP_TV
  # Description for progress bar
  tumor_str = "WITH TUMOR" if withTumor else "WITHOUT TUMOR"
  device_str = f"GPU no.{torch.cuda.current_device()}" if device.type == "cuda" else "CPU"
- description = f"AOT-BioMaps -- Primal/Dual Reconstruction (TV) α:{alpha:.4f} L:{L:.4f} -- {tumor_str} -- {device_str}"
+ description = f"AOT-BioMaps -- Primal/Dual Reconstruction (LS-TV) α:{alpha:.4f} L:{L:.4f} -- {tumor_str} -- {device_str}"

- # Main loop
- for iteration in trange(numIterations, desc=description):
+ iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+ for it in iterator:
  # Update p (TV proximal step)
  grad_x = gradient(x_tilde.reshape(Z, X))
  p = proj_l2(p + sigma * grad_x, alpha)
@@ -113,9 +213,9 @@
  x_tilde = x + theta * (x - x_old)

  # Save intermediate result if needed
- if isSavingEachIteration and iteration in save_indices:
+ if isSavingEachIteration and it in save_indices:
  I_reconMatrix.append(x.reshape(Z, X).clone() * (norm_y / norm_A))
- saved_indices.append(iteration)
+ saved_indices.append(it)

  # Return results
  if isSavingEachIteration:
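For reference, the dense loop in these hunks is the standard Chambolle-Pock iteration for min_x ½‖Ax - y‖² + α·TV(x). In the code's notation, where proj_l2 clips the dual TV variable to the α-ball and prox_F_star and prox_G are the resolvents of the data term and the primal term, one iteration reads:

    p^{k+1} = \Pi_{\|\cdot\|_2 \le \alpha}\big(p^k + \sigma \nabla \tilde{x}^k\big)
    q^{k+1} = \mathrm{prox}_{\sigma F^*}\big(q^k + \sigma A \tilde{x}^k - \sigma y\big)
    x^{k+1} = \mathrm{prox}_{\tau G}\big(x^k + \tau\,\mathrm{div}\,p^{k+1} - \tau A^{\top} q^{k+1}\big)
    \tilde{x}^{k+1} = x^{k+1} + \theta\,(x^{k+1} - x^k)

Convergence requires σ·τ·‖K‖² ≤ 1 for the stacked operator K = (A, ∇), which is presumably why both the dense and SELL paths estimate the operator norm with a power method when L is not supplied.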
@@ -123,6 +223,337 @@
  else:
  return (x.reshape(Z, X) * (norm_y / norm_A)).cpu().numpy(), None

+ def CP_TV_Tikhonov_sparseSELL_pycuda(
+ SMatrix,
+ y,
+ alpha=None, # TV regularization parameter (if None, alpha is auto-scaled)
+ beta=1e-4, # Tikhonov regularization parameter
+ theta=1.0,
+ numIterations=2000,
+ isSavingEachIteration=True,
+ L=None,
+ tumor_str="",
+ device=None,
+ max_saves=2000,
+ show_logs=True,
+ k_security=0.8,
+ use_power_method=True,
+ auto_alpha_gamma=0.05, # gamma for auto alpha: alpha = gamma * data_term / tv_term
+ apply_positivity_clamp=True,
+ tikhonov_as_gradient=False, # if True, apply -tau*2*beta*x instead of prox multiplicative
+ use_laplacian=True, # enable Laplacian (Hessian scalar) penalty
+ laplacian_beta_scale=1.0 # multiply beta for laplacian term if you want separate scaling
+ ):
+
+ """
+ CP-TV + Tikhonov + Laplacian (Hessian scalar) penalty integrated.
+ Returns (I_reconMatrix, saved_indices) if isSavingEachIteration else (x_final, None).
+ """
+ # ----- begin main -----
+ if SMatrix.ctx:
+ SMatrix.ctx.push()
+
+ # prepare variables
+ dtype = np.float32
+ TN = int(SMatrix.N * SMatrix.T)
+ ZX = int(SMatrix.Z * SMatrix.X)
+ Z, X = SMatrix.Z, SMatrix.X
+ block_size = 256
+
+ # existing kernels
+ projection_kernel = SMatrix.sparse_mod.get_function("projection_kernel__SELL")
+ backprojection_kernel = SMatrix.sparse_mod.get_function("backprojection_kernel__SELL")
+ axpby_kernel = SMatrix.sparse_mod.get_function("vector_axpby_kernel")
+ minus_axpy_kernel = SMatrix.sparse_mod.get_function("vector_minus_axpy_kernel")
+ gradient_kernel = SMatrix.sparse_mod.get_function("gradient_kernel")
+ divergence_kernel = SMatrix.sparse_mod.get_function("divergence_kernel")
+ proj_tv_kernel = SMatrix.sparse_mod.get_function("proj_tv_kernel")
+
+ # optional kernels (laplacian & clamp)
+ has_laplacian = False
+ has_clamp_kernel = False
+ try:
+ laplacian_kernel = SMatrix.sparse_mod.get_function("laplacian_kernel")
+ laplacian_adj_kernel = SMatrix.sparse_mod.get_function("laplacian_adj_kernel")
+ has_laplacian = True
+ except Exception:
+ has_laplacian = False
+
+ try:
+ clamp_positive_kernel = SMatrix.sparse_mod.get_function("clamp_positive_kernel")
+ has_clamp_kernel = True
+ except Exception:
+ has_clamp_kernel = False
+
+ stream = drv.Stream()
+
+ # estimate L operator norm if needed
+ if use_power_method or L is None:
+ L_LS_sq = power_method_estimate_L__SELL(SMatrix, stream, n_it=20, block_size=block_size)
+ L_nabla_sq = 8.0
+ L_op_norm = np.sqrt(L_LS_sq + L_nabla_sq)
+ if L_op_norm < 1e-6:
+ L_op_norm = 1.0
+ else:
+ L_op_norm = L
+
+ tau = np.float32(k_security / L_op_norm)
+ sigma = np.float32(k_security / L_op_norm)
+
+ # prepare y and normalization
+ y = y.T.astype(dtype).reshape(-1)
+ maxy = float(np.max(np.abs(y))) if y.size > 0 else 0.0
+ if maxy > 0:
+ y_normed = (y / maxy).copy()
+ else:
+ y_normed = y.copy()
+
+ # GPU allocations
+ bufs = []
+ y_gpu = drv.mem_alloc(y_normed.nbytes); bufs.append(y_gpu)
+ drv.memcpy_htod_async(y_gpu, y_normed.T.flatten(), stream)
+
+ x_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize); bufs.append(x_gpu)
+ drv.memset_d32_async(x_gpu, 0, ZX, stream)
+ x_old_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize); bufs.append(x_old_gpu)
+ x_tilde_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize); bufs.append(x_tilde_gpu)
+ drv.memcpy_dtod_async(x_tilde_gpu, x_gpu, ZX * np.dtype(dtype).itemsize, stream)
+
+ p_gpu = drv.mem_alloc(2 * ZX * np.dtype(dtype).itemsize); bufs.append(p_gpu)
+ q_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize); bufs.append(q_gpu)
+ drv.memset_d32_async(p_gpu, 0, 2 * ZX, stream)
+ drv.memset_d32_async(q_gpu, 0, TN, stream)
+
+ grad_gpu = drv.mem_alloc(2 * ZX * np.dtype(dtype).itemsize); bufs.append(grad_gpu)
+ div_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize); bufs.append(div_gpu)
+ Ax_gpu = drv.mem_alloc(TN * np.dtype(dtype).itemsize); bufs.append(Ax_gpu)
+ ATq_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize); bufs.append(ATq_gpu)
+ zero_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize); bufs.append(zero_gpu)
+ drv.memset_d32_async(zero_gpu, 0, ZX, stream)
+
+ # Laplacian buffers (if enabled and kernel available)
+ use_lap = use_laplacian and has_laplacian and (beta > 0)
+ if use_lap:
+ lap_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize); bufs.append(lap_gpu)
+ r_gpu = drv.mem_alloc(ZX * np.dtype(dtype).itemsize); bufs.append(r_gpu)
+ drv.memset_d32_async(r_gpu, 0, ZX, stream)
+ # scalar beta for laplacian (allow separate scale)
+ beta_lap = float(beta) * float(laplacian_beta_scale)
+ inv_1_plus_sigma_beta = np.float32(1.0 / (1.0 + float(sigma) * beta_lap))
+
+ # host buffers for logs
+ x_host = np.empty(ZX, dtype=dtype)
+ Ax_host = np.empty(TN, dtype=dtype)
+ q_host = np.empty(TN, dtype=dtype)
+ p_host = np.empty(2 * ZX, dtype=dtype)
+ ATq_host = np.empty(ZX, dtype=dtype)
+
+ # compute initial backprojection for auto-alpha
+ drv.memset_d32_async(ATq_gpu, 0, ZX, stream)
+ backprojection_kernel(SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, SMatrix.slice_ptr_gpu, SMatrix.slice_len_gpu,
+ y_gpu, ATq_gpu, np.int32(TN), np.int32(SMatrix.slice_height),
+ block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+ stream.synchronize()
+ drv.memcpy_dtoh(x_host, ATq_gpu)
+
+ # auto alpha if requested
+ if alpha is None:
+ drv.memcpy_htod_async(x_gpu, x_host, stream)
+ projection_kernel(Ax_gpu, SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, SMatrix.slice_ptr_gpu, SMatrix.slice_len_gpu,
+ x_gpu, np.int32(TN), np.int32(SMatrix.slice_height),
+ block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+ stream.synchronize()
+ drv.memcpy_dtoh(Ax_host, Ax_gpu)
+ resid = Ax_host - y_normed[:TN]
+ data_term = 0.5 * float(np.dot(resid, resid))
+ tv_term = float(compute_TV_cpu(x_host, Z, X)) + 1e-12
+ alpha = float(auto_alpha_gamma * data_term / tv_term)
+ if show_logs:
+ print(f"[auto-alpha] data_term={data_term:.6e}, tv_term={tv_term:.6e}, alpha_set={alpha:.6e}")
+
+ # tikhonov prox multiplicative scale
+ if tikhonov_as_gradient:
+ tikh_scale = None
+ else:
+ tikh_scale = np.float32(1.0 / (1.0 + 2.0 * tau * beta)) if beta > 0 else np.float32(1.0)
+
+ # saving policy
+ if numIterations <= max_saves:
+ save_indices_all = list(range(0, numIterations + 1))
+ else:
+ step = max(1, numIterations // max_saves)
+ save_indices_all = list(range(0, numIterations + 1, step))
+
+ device_str = f"GPU no.{torch.cuda.current_device()}" if device.type == "cuda" else "CPU"
+ if show_logs:
+ if (alpha is None or alpha == 0) and (beta is None or beta == 0):
+ print(f"Parameters: L={L_op_norm:.6e} tau={tau:.3e} sigma={sigma:.3e} lap_enabled={use_lap}")
+ description = f"AOT-BioMaps -- Primal/Dual Reconstruction (LS) -- {tumor_str} -- {device_str}"
+ if alpha is None or alpha == 0:
+ print(f"Parameters: L={L_op_norm:.6e} tau={tau:.3e} sigma={sigma:.3e} beta={beta:.4e} lap_enabled={use_lap}")
+ description = f"AOT-BioMaps -- Primal/Dual Reconstruction (LS-Tikhonov) -- {tumor_str} -- {device_str}"
+ elif beta is None or beta == 0:
+ print(f"Parameters: L={L_op_norm:.6e} tau={tau:.3e} sigma={sigma:.3e} alpha={alpha:.4e} beta={beta:.4e} lap_enabled={use_lap}")
+ description = f"AOT-BioMaps -- Primal/Dual Reconstruction (LS-TV) -- {tumor_str} -- {device_str}"
+ else:
+ print(f"Parameters: L={L_op_norm:.6e} tau={tau:.3e} sigma={sigma:.3e} alpha={alpha:.4e} beta={beta:.4e} lap_enabled={use_lap}")
+ description = f"AOT-BioMaps -- Primal/Dual Reconstruction (LS-TV-Tikhonov) -- {tumor_str} -- {device_str}"
+
+ I_reconMatrix = []
+ saved_indices = []
+ if isSavingEachIteration and 0 in save_indices_all:
+ drv.memcpy_dtoh(x_host, x_gpu)
+ x0 = x_host.reshape((Z, X)).copy()
+ if maxy > 0:
+ x0 *= maxy
+ I_reconMatrix.append(x0)
+ saved_indices.append(0)
+
+ # main loop
+ try:
+ iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+ for it in iterator:
+ # 1) dual p update (TV)
+ gradient_kernel(grad_gpu, x_tilde_gpu, np.int32(Z), np.int32(X), np.int32(ZX),
+ block=(block_size, 1, 1),
+ grid=((X + block_size - 1) // block_size, (Z + block_size - 1) // block_size, 1),
+ stream=stream)
+ _call_axpby(axpby_kernel, p_gpu, p_gpu, grad_gpu, 1.0, sigma, 2 * ZX, stream, block_size)
+ proj_tv_kernel(p_gpu, np.float32(alpha), np.int32(ZX),
+ block=(block_size, 1, 1),
+ grid=((ZX + block_size - 1) // block_size, 1, 1),
+ stream=stream)
+
+ # 2) dual q update (data fidelity)
+ projection_kernel(Ax_gpu, SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, SMatrix.slice_ptr_gpu, SMatrix.slice_len_gpu,
+ x_tilde_gpu, np.int32(TN), np.int32(SMatrix.slice_height),
+ block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+ _call_axpby(axpby_kernel, Ax_gpu, Ax_gpu, y_gpu, 1.0, -1.0, TN, stream, block_size)
+ _call_axpby(axpby_kernel, q_gpu, q_gpu, Ax_gpu, 1.0 / (1.0 + sigma), sigma / (1.0 + sigma), TN, stream, block_size)
+
+ # optional Laplacian dual update
+ if use_lap:
+ # compute Laplacian of x_tilde -> lap_gpu
+ laplacian_kernel(lap_gpu, x_tilde_gpu, np.int32(Z), np.int32(X), np.int32(ZX),
+ block=(block_size, 1, 1),
+ grid=((X + block_size - 1) // block_size, (Z + block_size - 1) // block_size, 1),
+ stream=stream)
+ # r = r + sigma * lap
+ _call_axpby(axpby_kernel, r_gpu, r_gpu, lap_gpu, 1.0, sigma, ZX, stream, block_size)
+ # r = r / (1 + sigma * beta_lap)
+ _call_axpby(axpby_kernel, r_gpu, r_gpu, zero_gpu, inv_1_plus_sigma_beta, 0.0, ZX, stream, block_size)
+
+ # 3) primal x update
+ drv.memcpy_dtod_async(x_old_gpu, x_gpu, ZX * np.dtype(dtype).itemsize, stream)
+ divergence_kernel(div_gpu, p_gpu, np.int32(Z), np.int32(X), np.int32(ZX),
+ block=(block_size, 1, 1),
+ grid=((X + block_size - 1) // block_size, (Z + block_size - 1) // block_size, 1),
+ stream=stream)
+ drv.memset_d32_async(ATq_gpu, 0, ZX, stream)
+ backprojection_kernel(SMatrix.sell_values_gpu, SMatrix.sell_colinds_gpu, SMatrix.slice_ptr_gpu, SMatrix.slice_len_gpu,
+ q_gpu, ATq_gpu, np.int32(TN), np.int32(SMatrix.slice_height),
+ block=(block_size, 1, 1), grid=((TN + block_size - 1) // block_size, 1, 1), stream=stream)
+ # ATq - div
+ _call_minus_axpy(minus_axpy_kernel, ATq_gpu, div_gpu, 1.0, ZX, stream, block_size)
+
+ # if laplacian is used, add H^T r into ATq
+ if use_lap:
+ # compute laplacian_adj_kernel(temp, r)
+ # reuse grad_gpu as temporary if safe (its content used earlier, but not reused until later)
+ laplacian_adj_kernel(grad_gpu, r_gpu, np.int32(Z), np.int32(X), np.int32(ZX),
+ block=(block_size, 1, 1),
+ grid=((X + block_size - 1) // block_size, (Z + block_size - 1) // block_size, 1),
+ stream=stream)
+ # ATq_gpu += temp (grad_gpu)
+ _call_axpby(axpby_kernel, ATq_gpu, ATq_gpu, grad_gpu, 1.0, 1.0, ZX, stream, block_size)
+
+ # x = x_old - tau * ATq_buffer
+ _call_minus_axpy(minus_axpy_kernel, x_gpu, ATq_gpu, tau, ZX, stream, block_size)
+
+ # Tikhonov
+ if beta > 0:
+ if tikhonov_as_gradient:
+ mul = 1.0 - 2.0 * float(tau) * float(beta)
+ if mul <= 0.0:
+ # fallback to prox multiplicative stable
+ fallback_scale = np.float32(1.0 / (1.0 + 2.0 * float(tau) * float(beta)))
+ _call_axpby(axpby_kernel, x_gpu, x_gpu, zero_gpu, fallback_scale, 0.0, ZX, stream, block_size)
+ else:
+ # x *= mul => implemented as axpby: out = 1* x + (mul-1)*x
+ _call_axpby(axpby_kernel, x_gpu, x_gpu, x_gpu, 1.0, np.float32(mul - 1.0), ZX, stream, block_size)
+ else:
+ _call_axpby(axpby_kernel, x_gpu, x_gpu, zero_gpu, tikh_scale, np.float32(0.0), ZX, stream, block_size)
+
+ # positivity clamp (prefer GPU kernel if available)
+ if apply_positivity_clamp:
+ if has_clamp_kernel:
+ # in-place clamp on GPU
+ clamp_positive_kernel(x_gpu, np.int32(ZX),
+ block=(block_size, 1, 1),
+ grid=((ZX + block_size - 1) // block_size, 1, 1),
+ stream=stream)
+ else:
+ # fallback CPU roundtrip (slower)
+ stream.synchronize()
+ drv.memcpy_dtoh(x_host, x_gpu)
+ np.maximum(x_host, 0.0, out=x_host)
+ drv.memcpy_htod_async(x_gpu, x_host, stream)
+
+ # extrapolation
+ _call_axpby(axpby_kernel, x_tilde_gpu, x_gpu, x_old_gpu, np.float32(1.0 + theta), np.float32(-theta), ZX, stream, block_size)
+
+ # saves
+ if isSavingEachIteration and (it + 1) in save_indices_all:
+ stream.synchronize()
+ drv.memcpy_dtoh(x_host, x_gpu)
+ x_saved = x_host.reshape((Z, X)).copy()
+ if maxy > 0:
+ x_saved *= maxy
+ I_reconMatrix.append(x_saved)
+ saved_indices.append(it + 1)
+
+ stream.synchronize()
+ drv.memcpy_dtoh(x_host, x_gpu)
+ x_final = x_host.reshape((Z, X)).copy()
+ if maxy > 0:
+ x_final *= maxy
+ if isSavingEachIteration and len(I_reconMatrix):
+ for i in range(len(I_reconMatrix)):
+ I_reconMatrix[i] *= maxy
+
+ # free buffers
+ for buff in bufs:
+ try:
+ buff.free()
+ except:
+ pass
+
+ if SMatrix.ctx:
+ SMatrix.ctx.pop()
+
+ if isSavingEachIteration:
+ return I_reconMatrix, saved_indices
+ else:
+ return x_final, None
+
+ except Exception as e:
+ # cleanup robustly
+ print("Error in CP_TV_Tikhonov+Lap (robust):", e)
+ try:
+ for buff in bufs:
+ try:
+ buff.free()
+ except:
+ pass
+ except:
+ pass
+ try:
+ if SMatrix and hasattr(SMatrix, 'ctx') and SMatrix.ctx:
+ SMatrix.ctx.pop()
+ except:
+ pass
+ raise
+

  def CP_KL(
  SMatrix,
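The SELL loop above leans on two fused vector kernels whose definitions live elsewhere; from the call sites, _call_axpby(kernel, out, x, y, a, b, n, ...) computes out = a·x + b·y and _call_minus_axpy(kernel, x, y, a, n, ...) computes x = x - a·y. A CPU mock under that inferred (not confirmed) signature, reproducing the Tikhonov proximal step x <- x/(1 + 2τβ) and the extrapolation step exactly as the kernels are invoked:

    import numpy as np

    def axpby(out, x, y, a, b):
        # out = a*x + b*y, elementwise (mirrors how vector_axpby_kernel is called)
        np.multiply(x, a, out=out)
        out += b * y

    tau, beta, theta = 0.1, 1e-4, 1.0
    x = np.random.rand(16).astype(np.float32)
    x_old = x.copy()
    zero = np.zeros_like(x)

    # Tikhonov prox, written as axpby(x, x, zero, 1/(1+2*tau*beta), 0) like the GPU path
    axpby(x, x, zero, 1.0 / (1.0 + 2.0 * tau * beta), 0.0)

    # extrapolation: x_tilde = (1 + theta)*x - theta*x_old
    x_tilde = np.empty_like(x)
    axpby(x_tilde, x, x_old, 1.0 + theta, -theta)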
@@ -132,9 +563,10 @@ def CP_KL(
  numIterations=5000,
  isSavingEachIteration=True,
  L=None,
- withTumor=True,
+ tumor_str="",
  device=None,
  max_saves=5000,
+ show_logs=True,
  ):
  """
  Chambolle-Pock algorithm for Kullback-Leibler (KL) divergence regularization.
@@ -193,12 +625,11 @@ def CP_KL(
  saved_indices = [0]

  # Description for progress bar
- tumor_str = "WITH TUMOR" if withTumor else "WITHOUT TUMOR"
  device_str = f"GPU no.{torch.cuda.current_device()}" if device.type == "cuda" else "CPU"
  description = f"AOT-BioMaps -- Primal/Dual Reconstruction (KL) α:{alpha:.4f} L:{L:.4f} -- {tumor_str} -- {device_str}"

- # Main loop
- for iteration in trange(numIterations, desc=description):
+ iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+ for iteration in iterator:
  # Update q (proximal step for F*)
  q = prox_F_star(q + sigma * P(x_tilde) - sigma * y_flat, sigma, y_flat)
 
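In the CP_KL q-update above, prox_F_star receives the data y as a third argument; for a Kullback-Leibler data term this proximal map has the standard closed form (Chambolle and Pock, 2011), which is presumably what the ReconTools helper implements:

    \mathrm{prox}_{\sigma F^*}(\tilde{q})_i = \tfrac{1}{2}\Big(1 + \tilde{q}_i - \sqrt{(\tilde{q}_i - 1)^2 + 4\,\sigma\, y_i}\Big)

For y_i > 0 this keeps every dual coordinate strictly below 1, as the KL conjugate F*(q) = -Σ_i y_i log(1 - q_i) requires.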
AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/RelativeDifferences.py

@@ -9,26 +9,24 @@ def _Omega_RELATIVE_DIFFERENCE_CPU(theta_flat, index, values, gamma):
  theta_k = theta_flat[k_idx]
  diff = theta_k - theta_j
  abs_diff = np.abs(diff)
-
  denom = theta_k + theta_j + gamma * abs_diff + 1e-8
  num = diff ** 2
-
+ psi_pair = num / denom
+ psi_pair = values * psi_pair
  # First derivative ∂U/∂θ_j
  dpsi = (2 * diff * denom - num * (1 + gamma * np.sign(diff))) / (denom ** 2)
  grad_pair = values * (-dpsi) # Note the negative sign: U contains ψ(θ_k, θ_j), seeking ∂/∂θ_j
-
  # Second derivative ∂²U/∂θ_j² (numerically stable, approximate treatment)
  d2psi = (2 * denom ** 2 - 4 * diff * denom * (1 + gamma * np.sign(diff))
  + 2 * num * (1 + gamma * np.sign(diff)) ** 2) / (denom ** 3 + 1e-8)
  hess_pair = values * d2psi
-
  grad_U = np.zeros_like(theta_flat)
  hess_U = np.zeros_like(theta_flat)
-
  np.add.at(grad_U, j_idx, grad_pair)
  np.add.at(hess_U, j_idx, hess_pair)
-
- return grad_U, hess_U
+ # Compute U_value
+ U_value = 0.5 * np.sum(psi_pair)
+ return grad_U, hess_U, U_value

  def _Omega_RELATIVE_DIFFERENCE_GPU(theta_flat, index, values, device, gamma):
  j_idx, k_idx = index
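Both routines evaluate the relative difference potential over a sparse neighbor graph given by index = (j_idx, k_idx) with pair weights values, and now also return the potential's value. Writing d = θ_k - θ_j and s = θ_k + θ_j + γ|d| (plus the 1e-8 guard), the quantities psi_pair, U_value, and dpsi in these hunks correspond to:

    \psi(\theta_k, \theta_j) = \frac{d^2}{s}, \qquad
    U(\theta) = \tfrac{1}{2} \sum_{(j,k)} w_{jk}\, \psi(\theta_k, \theta_j), \qquad
    \frac{\partial \psi}{\partial \theta_k} = \frac{2 d s - d^2 (1 + \gamma\, \mathrm{sgn}\, d)}{s^2}

The γ term makes the penalty grow roughly linearly rather than quadratically across large jumps, which is what gives this prior its edge-preserving behavior; the ½ in U_value presumably halves the double count when both orientations of a pair appear in the index lists.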
@@ -38,26 +36,24 @@ def _Omega_RELATIVE_DIFFERENCE_GPU(theta_flat, index, values, device, gamma):
  abs_diff = torch.abs(diff)
  denom = theta_k + theta_j + gamma * abs_diff + 1e-8
  num = diff ** 2
-
+ psi_pair = num / denom
+ psi_pair = values * psi_pair
  # Compute gradient contributions
  dpsi = (2 * diff * denom - num * (1 + gamma * torch.sign(diff))) / (denom ** 2)
  grad_pair = values * (-dpsi)
-
  # Compute Hessian contributions
  d2psi = (2 * denom ** 2 - 4 * diff * denom * (1 + gamma * torch.sign(diff))
  + 2 * num * (1 + gamma * torch.sign(diff)) ** 2) / (denom ** 3 + 1e-8)
  hess_pair = values * d2psi
-
  # Initialize gradient and Hessian on the correct device
  grad_U = torch.zeros_like(theta_flat, device=device)
  hess_U = torch.zeros_like(theta_flat, device=device)
-
  # Accumulate gradient contributions
  grad_U.index_add_(0, j_idx, grad_pair)
  grad_U.index_add_(0, k_idx, -grad_pair)
-
  # Accumulate Hessian contributions
  hess_U.index_add_(0, j_idx, hess_pair)
  hess_U.index_add_(0, k_idx, hess_pair)
-
- return grad_U, hess_U
+ # Compute U_value
+ U_value = 0.5 * psi_pair.sum()
+ return grad_U, hess_U, U_value
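A quick way to sanity-check the value/gradient pair these routines now return is a central finite-difference probe of U; a self-contained sketch against a plain NumPy restatement of ψ (the toy graph and weights are invented for the demo, and this harness is not part of the package):

    import numpy as np

    def U_value(theta, j_idx, k_idx, w, gamma, eps=1e-8):
        # 0.5 * weighted sum of pairwise relative-difference penalties,
        # mirroring psi_pair / U_value in the hunks above
        d = theta[k_idx] - theta[j_idx]
        s = theta[k_idx] + theta[j_idx] + gamma * np.abs(d) + eps
        return 0.5 * float(np.sum(w * d * d / s))

    def grad_U_fd(theta, j_idx, k_idx, w, gamma, h=1e-6):
        # central finite differences of U, for probing any analytic gradient
        g = np.zeros_like(theta)
        for i in range(theta.size):
            tp, tm = theta.copy(), theta.copy()
            tp[i] += h
            tm[i] -= h
            g[i] = (U_value(tp, j_idx, k_idx, w, gamma)
                    - U_value(tm, j_idx, k_idx, w, gamma)) / (2 * h)
        return g

    theta = np.array([1.0, 2.0, 1.5, 3.0])   # toy image, flattened
    j_idx = np.array([0, 1, 2])              # invented neighbor pairs
    k_idx = np.array([1, 2, 3])
    w = np.ones(3)
    print(U_value(theta, j_idx, k_idx, w, gamma=2.0))
    print(grad_U_fd(theta, j_idx, k_idx, w, gamma=2.0))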