AOT-biomaps 2.1.3__py3-none-any.whl → 2.9.233__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of AOT-biomaps might be problematic. Click here for more details.

Files changed (50)
  1. AOT_biomaps/AOT_Acoustic/AcousticEnums.py +64 -0
  2. AOT_biomaps/AOT_Acoustic/AcousticTools.py +221 -0
  3. AOT_biomaps/AOT_Acoustic/FocusedWave.py +244 -0
  4. AOT_biomaps/AOT_Acoustic/IrregularWave.py +66 -0
  5. AOT_biomaps/AOT_Acoustic/PlaneWave.py +43 -0
  6. AOT_biomaps/AOT_Acoustic/StructuredWave.py +392 -0
  7. AOT_biomaps/AOT_Acoustic/__init__.py +15 -0
  8. AOT_biomaps/AOT_Acoustic/_mainAcoustic.py +978 -0
  9. AOT_biomaps/AOT_Experiment/Focus.py +55 -0
  10. AOT_biomaps/AOT_Experiment/Tomography.py +505 -0
  11. AOT_biomaps/AOT_Experiment/__init__.py +9 -0
  12. AOT_biomaps/AOT_Experiment/_mainExperiment.py +532 -0
  13. AOT_biomaps/AOT_Optic/Absorber.py +24 -0
  14. AOT_biomaps/AOT_Optic/Laser.py +70 -0
  15. AOT_biomaps/AOT_Optic/OpticEnums.py +17 -0
  16. AOT_biomaps/AOT_Optic/__init__.py +10 -0
  17. AOT_biomaps/AOT_Optic/_mainOptic.py +204 -0
  18. AOT_biomaps/AOT_Recon/AOT_Optimizers/DEPIERRO.py +191 -0
  19. AOT_biomaps/AOT_Recon/AOT_Optimizers/LS.py +106 -0
  20. AOT_biomaps/AOT_Recon/AOT_Optimizers/MAPEM.py +456 -0
  21. AOT_biomaps/AOT_Recon/AOT_Optimizers/MLEM.py +333 -0
  22. AOT_biomaps/AOT_Recon/AOT_Optimizers/PDHG.py +221 -0
  23. AOT_biomaps/AOT_Recon/AOT_Optimizers/__init__.py +5 -0
  24. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/Huber.py +90 -0
  25. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/Quadratic.py +86 -0
  26. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/RelativeDifferences.py +59 -0
  27. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/__init__.py +3 -0
  28. AOT_biomaps/AOT_Recon/AlgebraicRecon.py +1023 -0
  29. AOT_biomaps/AOT_Recon/AnalyticRecon.py +154 -0
  30. AOT_biomaps/AOT_Recon/BayesianRecon.py +230 -0
  31. AOT_biomaps/AOT_Recon/DeepLearningRecon.py +35 -0
  32. AOT_biomaps/AOT_Recon/PrimalDualRecon.py +210 -0
  33. AOT_biomaps/AOT_Recon/ReconEnums.py +375 -0
  34. AOT_biomaps/AOT_Recon/ReconTools.py +273 -0
  35. AOT_biomaps/AOT_Recon/__init__.py +11 -0
  36. AOT_biomaps/AOT_Recon/_mainRecon.py +288 -0
  37. AOT_biomaps/Config.py +95 -0
  38. AOT_biomaps/Settings.py +45 -13
  39. AOT_biomaps/__init__.py +271 -18
  40. aot_biomaps-2.9.233.dist-info/METADATA +22 -0
  41. aot_biomaps-2.9.233.dist-info/RECORD +43 -0
  42. {AOT_biomaps-2.1.3.dist-info → aot_biomaps-2.9.233.dist-info}/WHEEL +1 -1
  43. AOT_biomaps/AOT_Acoustic.py +0 -1881
  44. AOT_biomaps/AOT_Experiment.py +0 -541
  45. AOT_biomaps/AOT_Optic.py +0 -219
  46. AOT_biomaps/AOT_Reconstruction.py +0 -1416
  47. AOT_biomaps/config.py +0 -54
  48. AOT_biomaps-2.1.3.dist-info/METADATA +0 -20
  49. AOT_biomaps-2.1.3.dist-info/RECORD +0 -11
  50. {AOT_biomaps-2.1.3.dist-info → aot_biomaps-2.9.233.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,333 @@
1
+ from AOT_biomaps.AOT_Recon.ReconTools import _forward_projection, _backward_projection, check_gpu_memory, calculate_memory_requirement
2
+ from AOT_biomaps.Config import config
3
+ import numba
4
+ import torch
5
+ import numpy as np
6
+ import os
7
+ from tqdm import trange
8
+ import cupy as cp
9
+ import cupyx.scipy.sparse as cpsparse
10
+ import gc
11
+
12
+
13
def MLEM(
    SMatrix,
    y,
    numIterations=100,
    isSavingEachIteration=True,
    withTumor=True,
    device=None,
    use_numba=False,
    denominator_threshold=1e-6,
    max_saves=5000,
    show_logs=True,
    useSparseSMatrix=True,
    Z=350,
):
    """
    Unified MLEM algorithm for Acousto-Optic Tomography.

    Dispatches to a GPU implementation (sparse CSR via CuPy, or dense torch)
    or a CPU implementation (vectorized NumPy, or Numba projections).

    Args:
        SMatrix: System matrix. The dense paths expect shape (T, Z, X, N).
            The sparse GPU path (``useSparseSMatrix=True``) converts it with
            ``cupyx.scipy.sparse.csr_matrix`` and unpacks ``TN, ZX = shape``,
            so it must be 2-D (T*N, Z*X) there.
        y: Measurement data (shape: T, N).
        numIterations: Number of iterations.
        isSavingEachIteration: If True, returns selected intermediate results.
        withTumor: Used only in the progress-bar description.
        device: Torch device or device string (e.g. "cpu", "cuda:0").
            Auto-selected (best GPU with enough memory, else CPU) if None.
        use_numba: If True and running on CPU, uses the Numba projections.
        denominator_threshold: Threshold applied to the forward projection
            before the ratio y / (A theta) is formed.
        max_saves: Maximum number of intermediate saves (default: 5000).
        show_logs: If True, shows a progress bar (and memory-check logs).
        useSparseSMatrix: On GPU, use the sparse CSR CuPy implementation.
        Z: Number of image rows; only used by the sparse path to reshape
            the flat result.

    Returns:
        (reconstruction(s), saved iteration indices or None), or
        (None, None) if any error occurs (the error is printed).
    """
    try:
        tumor_str = "WITH" if withTumor else "WITHOUT"

        if device is None:
            # Query the best GPU once so the memory check and the chosen
            # device refer to the same GPU.
            use_gpu = False
            if torch.cuda.is_available():
                gpu_index = config.select_best_gpu()
                use_gpu = check_gpu_memory(
                    gpu_index,
                    calculate_memory_requirement(SMatrix, y),
                    show_logs=show_logs,
                )
            device = torch.device(f"cuda:{gpu_index}") if use_gpu else torch.device("cpu")
        else:
            # Accept both torch.device objects and strings such as "cuda:1".
            device = torch.device(device)
            use_gpu = device.type == "cuda"

        if use_gpu:
            if useSparseSMatrix:
                return _MLEM_sparseCSR(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device.index, max_saves, denominator_threshold, Z, show_logs)
            return _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs)
        if use_numba:
            return _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs)
        return _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs)
    except Exception as e:
        print(f"Error in MLEM: {type(e).__name__}: {e}")
        return None, None
69
+
70
def _MLEM_single_GPU(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, device, max_saves, denominator_threshold, show_logs=True):
    """
    Dense MLEM on a single torch device.

    Args:
        SMatrix: NumPy system matrix of shape (T, Z, X, N).
        y: NumPy measurements with T*N elements (shape (T, N)).
        numIterations: Number of MLEM iterations.
        isSavingEachIteration: If True, return the selected intermediate images.
        tumor_str: "WITH"/"WITHOUT"; only used in the progress-bar description.
        device: torch.device (or device string) to run on.
        max_saves: Maximum number of intermediate saves.
        denominator_threshold: Where the forward projection is below this
            value, the ratio y / (A theta) is replaced by 1.
        show_logs: If True, show a tqdm progress bar.

    Returns:
        (list of (Z, X) arrays, list of iteration indices) if saving, where
        index ``it`` holds the image after update ``it`` (0-based); otherwise
        ((Z, X) array, None). (None, None) on error.
    """
    try:
        device = torch.device(device)
        eps = torch.finfo(torch.float32).eps
        T, Z, X, N = SMatrix.shape
        ZX = Z * X
        TN = T * N
        # Row r = t*N + n, column c = z*X + x.
        A_flat = (
            torch.from_numpy(SMatrix)
            .to(device=device, dtype=torch.float32)
            .permute(0, 3, 1, 2)
            .contiguous()
            .reshape(TN, ZX)
        )
        y_flat = torch.from_numpy(y).to(device=device, dtype=torch.float32).reshape(-1)
        theta_flat = torch.ones(ZX, dtype=torch.float32, device=device)
        # Column sums of A_flat equal SMatrix.sum(axis=(0, 3)).ravel(); computing
        # them from A_flat avoids uploading the system matrix a second time.
        norm_factor_flat = A_flat.sum(dim=0)

        # Label the bar with the device actually used, not the current CUDA device.
        if device.type == "cuda":
            gpu_id = device.index if device.index is not None else torch.cuda.current_device()
            device_label = f"GPU {gpu_id}"
        else:
            device_label = device.type.upper()
        description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- {device_label}"

        # Calculate save indices (always include the last iteration).
        if numIterations <= max_saves:
            save_indices = list(range(numIterations))
        else:
            step = numIterations // max_saves
            save_indices = list(range(0, numIterations, step))
            if save_indices[-1] != numIterations - 1:
                save_indices.append(numIterations - 1)
        save_set = set(save_indices)  # O(1) membership inside the loop

        saved_theta = []
        saved_indices = []
        with torch.no_grad():
            iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
            for it in iterator:
                q_flat = A_flat @ theta_flat
                # Threshold: where q_flat < denominator_threshold, the ratio is
                # set to 1 (as in the reference C++ code).
                mask = q_flat >= denominator_threshold
                e_flat = torch.where(mask, y_flat / (q_flat + eps), torch.ones_like(q_flat))
                c_flat = A_flat.T @ e_flat
                theta_flat = (theta_flat / (norm_factor_flat + eps)) * c_flat
                if isSavingEachIteration and it in save_set:
                    saved_theta.append(theta_flat.reshape(Z, X).clone())
                    saved_indices.append(it)

        # Free memory
        del A_flat, y_flat, norm_factor_flat
        torch.cuda.empty_cache()
        if isSavingEachIteration:
            return [t.cpu().numpy() for t in saved_theta], saved_indices
        return theta_flat.reshape(Z, X).cpu().numpy(), None
    except Exception as e:
        print(f"Error in single-GPU MLEM: {type(e).__name__}: {e}")
        torch.cuda.empty_cache()
        return None, None
126
+
127
def _MLEM_CPU_numba(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs=True):
    """
    MLEM on CPU using the Numba forward/backward projection kernels.

    Args:
        SMatrix: NumPy system matrix of shape (T, Z, X, N).
        y: NumPy measurements of shape (T, N).
        numIterations: Number of MLEM iterations.
        isSavingEachIteration: If True, return the selected intermediate images.
        tumor_str: "WITH"/"WITHOUT"; only used in the progress-bar description.
        max_saves: Maximum number of intermediate saves.
        denominator_threshold: Where the forward projection is below this
            value, the ratio y / (A theta) is replaced by 1.
        show_logs: If True, show a tqdm progress bar.

    Returns:
        If saving: (list of (Z, X) arrays, list of labels). Label 0 is the
        initial image; label k is the image after k updates. The final image
        (label numIterations) is always included.
        Otherwise: ((Z, X) final image, None). (None, None) on error.
    """
    try:
        numba.set_num_threads(os.cpu_count())
        # Output buffers: assumes _forward_projection / _backward_projection
        # overwrite (not accumulate into) q_p / c_p — TODO confirm in ReconTools.
        q_p = np.zeros((SMatrix.shape[0], SMatrix.shape[3]), dtype=np.float32)
        c_p = np.zeros((SMatrix.shape[1], SMatrix.shape[2]), dtype=np.float32)
        theta_p = np.ones((SMatrix.shape[1], SMatrix.shape[2]), dtype=np.float32)
        matrix_theta = [theta_p]
        saved_indices = [0]
        normalization_factor = np.sum(SMatrix, axis=(0, 3)).astype(np.float32)

        # Calculate save indices
        if numIterations <= max_saves:
            save_indices = list(range(numIterations))
        else:
            step = numIterations // max_saves
            save_indices = list(range(0, numIterations, step))
            if save_indices[-1] != numIterations - 1:
                save_indices.append(numIterations - 1)
        save_set = set(save_indices)

        # Report the thread count actually in effect after set_num_threads().
        description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- processing on multithread CPU ({numba.get_num_threads()} threads) ----"
        iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)

        for it in iterator:
            _forward_projection(SMatrix, theta_p, q_p)

            # Threshold: where q_p < denominator_threshold, the ratio is set to 1.
            mask = q_p >= denominator_threshold
            e_p = np.where(mask, y / (q_p + 1e-8), 1.0)

            _backward_projection(SMatrix, e_p, c_p)
            # A new array every iteration, so stored snapshots never alias c_p.
            theta_p = theta_p / (normalization_factor + 1e-8) * c_p

            # Keep the current iterate separate from the snapshot list so a
            # saved snapshot is never overwritten by a later iteration.
            done = it + 1
            if isSavingEachIteration and (done in save_set or done == numIterations):
                matrix_theta.append(theta_p)
                saved_indices.append(done)

        if not isSavingEachIteration:
            return theta_p, None
        return matrix_theta, saved_indices
    except Exception as e:
        print(f"Error in Numba CPU MLEM: {type(e).__name__}: {e}")
        return None, None
173
+
174
def _MLEM_CPU_opti(SMatrix, y, numIterations, isSavingEachIteration, tumor_str, max_saves, denominator_threshold, show_logs=True):
    """
    Vectorized NumPy MLEM on a single CPU.

    Args:
        SMatrix: NumPy system matrix of shape (T, Z, X, N).
        y: NumPy measurements with T*N elements (shape (T, N)).
        numIterations: Number of MLEM iterations.
        isSavingEachIteration: If True, return the selected intermediate images.
        tumor_str: "WITH"/"WITHOUT"; only used in the progress-bar description.
        max_saves: Maximum number of intermediate saves.
        denominator_threshold: Where the forward projection is below this
            value, the ratio y / (A theta) is replaced by 1.
        show_logs: If True, show a tqdm progress bar.

    Returns:
        If saving: (list of (Z, X) arrays, list of labels). Label 0 is the
        initial image; label k is the image after k updates. The final image
        (label numIterations) is always included.
        Otherwise: ((Z, X) final image, None). (None, None) on error.
    """
    try:
        T, Z, X, N = SMatrix.shape
        # Row r = t*N + n, column c = z*X + x.
        A_flat = SMatrix.astype(np.float32).transpose(0, 3, 1, 2).reshape(T * N, Z * X)
        y_flat = y.astype(np.float32).reshape(-1)
        theta = np.ones((Z, X), dtype=np.float32)
        matrix_theta = [theta]
        saved_indices = [0]
        normalization_factor = np.sum(SMatrix, axis=(0, 3)).astype(np.float32)
        normalization_factor_flat = normalization_factor.reshape(-1)
        tiny = np.finfo(np.float32).tiny

        # Calculate save indices
        if numIterations <= max_saves:
            save_indices = list(range(numIterations))
        else:
            step = numIterations // max_saves
            save_indices = list(range(0, numIterations, step))
            if save_indices[-1] != numIterations - 1:
                save_indices.append(numIterations - 1)
        save_set = set(save_indices)

        description = f"AOT-BioMaps -- ML-EM ---- {tumor_str} TUMOR ---- processing on single CPU (optimized) ----"
        iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)

        for it in iterator:
            theta_flat = theta.reshape(-1)
            q_flat = A_flat @ theta_flat

            # Threshold: where q_flat < denominator_threshold, the ratio is set to 1.
            mask = q_flat >= denominator_threshold
            e_flat = np.where(mask, y_flat / (q_flat + tiny), 1.0)

            c_flat = A_flat.T @ e_flat
            theta = (theta_flat / (normalization_factor_flat + tiny) * c_flat).reshape(Z, X)

            # Keep the current iterate separate from the snapshot list so a
            # saved snapshot is never overwritten by a later iteration.
            done = it + 1
            if isSavingEachIteration and (done in save_set or done == numIterations):
                matrix_theta.append(theta)
                saved_indices.append(done)

        if not isSavingEachIteration:
            return theta, None
        return matrix_theta, saved_indices
    except Exception as e:
        print(f"Error in optimized CPU MLEM: {type(e).__name__}: {e}")
        return None, None
223
+
224
+
225
+
226
def _MLEM_sparseCSR(
    SMatrix,
    y,
    numIterations,
    isSavingEachIteration,
    tumor_str,
    device_index,
    max_saves,
    denominator_threshold,
    Z,
    show_logs=True,
):
    """
    MLEM with a sparse CSR system matrix on a single GPU (CuPy).

    Args:
        SMatrix: 2-D system matrix of shape (T*N, Z*X). Converted to a float32
            ``cupyx.scipy.sparse.csr_matrix`` if it is not one already.
        y: Measurements with T*N elements (any shape; flattened to match the
            (T*N,) forward projection).
        numIterations: Number of MLEM iterations.
        isSavingEachIteration: If True, return the selected intermediate images.
        tumor_str: "WITH"/"WITHOUT"; only used in the progress-bar description.
        device_index: CUDA device index passed to ``cp.cuda.Device``.
        max_saves: Maximum number of intermediate saves.
        denominator_threshold: Lower clamp applied to the forward projection
            before dividing y by it.
        Z: Number of image rows; X is inferred as (Z*X) // Z.
        show_logs: If True, show a progress bar and run the every-10-iterations
            convergence check.

    Returns:
        (list of (Z, X) NumPy arrays, list of iteration indices) if saving,
        otherwise ((Z, X) NumPy array, None). (None, None) on error.
    """
    try:
        cp.cuda.Device(device_index).use()
        dtype = cp.float32
        eps = cp.finfo(dtype).eps

        # --- System matrix and data ---
        if not isinstance(SMatrix, cpsparse.csr_matrix):
            SMatrix = cpsparse.csr_matrix(SMatrix, dtype=dtype)
        else:
            SMatrix = SMatrix.astype(dtype)

        # Flatten y. A (T, N) or (T*N, 1) array would otherwise broadcast
        # against the (T*N,) projection and silently produce a 2-D ratio.
        if not isinstance(y, cp.ndarray):
            y_cupy = cp.asarray(y, dtype=dtype).ravel()
        else:
            y_cupy = y.astype(dtype).ravel()

        TN, ZX = SMatrix.shape
        if Z <= 0 or ZX % Z != 0:
            raise ValueError(f"Z={Z} does not divide the number of image pixels ({ZX})")
        if y_cupy.size != TN:
            raise ValueError(f"y has {y_cupy.size} elements but SMatrix has {TN} rows")
        X = ZX // Z
        # Transposed view for back-projection, built once instead of per iteration.
        SMatrix_T = SMatrix.T

        # Initial image
        theta_flat = cp.full(ZX, 0.1, dtype=dtype)

        # Sensitivity (column sums), clamped to avoid division by zero
        norm_factor = cp.maximum(SMatrix.sum(axis=0).ravel(), 1e-6)
        norm_factor_inv = 1.0 / norm_factor

        # Save indices (always include the last iteration)
        if numIterations <= max_saves:
            save_indices = list(range(numIterations))
        else:
            step = max(1, numIterations // max_saves)
            save_indices = list(range(0, numIterations, step))
            if save_indices[-1] != numIterations - 1:
                save_indices.append(numIterations - 1)
        save_set = set(save_indices)

        saved_theta = []
        saved_indices = []

        description = f"AOT-BioMaps -- ML-EM (sparse CSR) ---- {tumor_str} TUMOR ---- GPU {device_index}"
        iterator = trange(numIterations, desc=description, ncols=100) if show_logs else range(numIterations)

        # --- Main MLEM loop ---
        # Per-iteration free_all_blocks()/gc.collect() were removed: they
        # defeat CuPy's memory pool and dominate runtime; rebinding the loop
        # variables already releases the previous arrays back to the pool.
        theta_prev = theta_flat
        for it in iterator:
            theta_prev = theta_flat  # rebinding below never mutates this array

            # Step 1: forward projection, clamped from below
            q_flat = cp.maximum(SMatrix.dot(theta_flat), denominator_threshold)
            # Step 2: ratio y / (A theta)
            e_flat = y_cupy / q_flat
            # Step 3: back-projection A^T e
            c_flat = SMatrix_T.dot(e_flat)
            # Step 4: multiplicative update, kept non-negative
            theta_flat = cp.maximum(theta_flat * (norm_factor_inv * c_flat), 0)

            if isSavingEachIteration and it in save_set:
                saved_theta.append(theta_flat.reshape(Z, X).get())  # copy to host
                saved_indices.append(it)

            # Convergence check every 10 iterations: compare with the image
            # *before* this update (the original compared theta with itself,
            # which is always 0 and stopped the loop at iteration 0).
            # NOTE(review): the check only runs when show_logs is True, so the
            # logging flag affects the result; kept to preserve behavior.
            if it % 10 == 0 and show_logs:
                rel_change = cp.abs(theta_flat - theta_prev).max() / (theta_flat.max() + eps)
                if rel_change < 1e-4:
                    print(f"Convergence atteinte à l’itération {it}")
                    # Make sure the returned snapshots end with the final image.
                    if isSavingEachIteration and (not saved_indices or saved_indices[-1] != it):
                        saved_theta.append(theta_flat.reshape(Z, X).get())
                        saved_indices.append(it)
                    break

        # --- Result back on the host ---
        result = theta_flat.reshape(Z, X).get()
        del theta_flat, theta_prev, norm_factor, norm_factor_inv, SMatrix_T, SMatrix, y_cupy
        cp.get_default_memory_pool().free_all_blocks()
        gc.collect()

        if isSavingEachIteration:
            return saved_theta, saved_indices
        return result, None

    except Exception as e:
        print(f"Erreur dans _MLEM_sparseCSR: {type(e).__name__}: {e}")
        cp.get_default_memory_pool().free_all_blocks()
        gc.collect()
        return None, None
333
+
@@ -0,0 +1,221 @@
1
+ from AOT_biomaps.AOT_Recon.ReconTools import power_method, gradient, div, proj_l2, prox_G, prox_F_star
2
+ from AOT_biomaps.Config import config
3
+ from AOT_biomaps.AOT_Recon.ReconEnums import NoiseType
4
+ import torch
5
+ from tqdm import trange
6
+
7
+ '''
8
+ This module implements Primal-Dual Hybrid Gradient (PDHG) methods for solving inverse problems in Acousto-Optic Tomography.
9
+ It includes Chambolle-Pock algorithms for Total Variation (TV) and Kullback-Leibler (KL) divergence regularization.
10
+ The methods can run on both CPU and GPU, with configurations set in the AOT_biomaps.Config module.
11
+ '''
12
+
13
def CP_TV(
    SMatrix,
    y,
    alpha=1e-1,
    theta=1.0,
    numIterations=5000,
    isSavingEachIteration=True,
    L=None,
    withTumor=True,
    device=None,
    max_saves=5000,
):
    """
    Chambolle-Pock algorithm with a least-squares data term and Total Variation (TV) regularization.

    Works on both CPU and GPU. A and y are each scaled by their max-abs value
    before iterating; results are rescaled by norm_y / norm_A on output.

    Args:
        SMatrix: System matrix (shape: T, Z, X, N)
        y: Measurement data (shape: T, N)
        alpha: Regularization parameter for TV (radius passed to proj_l2)
        theta: Relaxation parameter (1.0 for standard Chambolle-Pock)
        numIterations: Number of iterations
        isSavingEachIteration: If True, returns selected intermediate reconstructions
        L: Operator-norm estimate (estimated with power_method if None; falls
            back to 1.0 if the estimate fails, and is clamped to >= 1e-3)
        withTumor: Boolean for description only
        device: Torch device or device string (auto-selected if None)
        max_saves: Maximum number of intermediate saves (default: 5000)

    Returns:
        (list of (Z, X) arrays, list of iteration indices) if saving,
        otherwise ((Z, X) array, None).
    """
    # Auto-select device if not provided
    if device is None:
        device = torch.device(f"cuda:{config.select_best_gpu()}" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    # Convert data to tensors and move to device
    A = torch.tensor(SMatrix, dtype=torch.float32, device=device)
    y = torch.tensor(y, dtype=torch.float32, device=device)
    T, Z, X, N = SMatrix.shape
    A_flat = A.permute(0, 3, 1, 2).reshape(T * N, Z * X)
    y_flat = y.reshape(-1)

    # Robust normalization
    norm_A = A_flat.abs().max().clamp(min=1e-8)
    norm_y = y_flat.abs().max().clamp(min=1e-8)
    A_flat = A_flat / norm_A
    y_flat = y_flat / norm_y
    scale = norm_y / norm_A  # undoes the normalization on output

    # Define forward/backward operators
    P = lambda x: torch.matmul(A_flat, x)
    PT = lambda y: torch.matmul(A_flat.T, y)

    # Estimate the operator norm if needed
    if L is None:
        try:
            L = power_method(P, PT, y_flat, Z, X)
            L = max(L, 1e-3)
        except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
            L = 1.0

    sigma = 1.0 / L
    tau = 1.0 / L

    # Initialize variables
    x = torch.zeros(Z * X, device=device)
    p = torch.zeros((2, Z, X), device=device)
    q = torch.zeros_like(y_flat)
    x_tilde = x.clone()

    # Calculate save indices (always include the last iteration)
    if numIterations <= max_saves:
        save_indices = list(range(numIterations))
    else:
        step = numIterations // max_saves
        save_indices = list(range(0, numIterations, step))
        if save_indices[-1] != numIterations - 1:
            save_indices.append(numIterations - 1)
    save_set = set(save_indices)

    I_reconMatrix = []
    saved_indices = []

    # Description for progress bar: label with the selected device, not the current one
    tumor_str = "WITH TUMOR" if withTumor else "WITHOUT TUMOR"
    if device.type == "cuda":
        gpu_id = device.index if device.index is not None else torch.cuda.current_device()
        device_str = f"GPU no.{gpu_id}"
    else:
        device_str = "CPU"
    description = f"AOT-BioMaps -- Primal/Dual Reconstruction (TV) α:{alpha:.4f} L:{L:.4f} -- {tumor_str} -- {device_str}"

    # Main loop
    for iteration in trange(numIterations, desc=description):
        # Dual update for TV: gradient ascent then projection onto the alpha-ball
        grad_x = gradient(x_tilde.reshape(Z, X))
        p = proj_l2(p + sigma * grad_x, alpha)

        # Dual update for the data term
        q = (q + sigma * (P(x_tilde) - y_flat)) / (1 + sigma)

        # Primal update (x is rebound, never mutated, so no clone is needed)
        x_old = x
        div_p = div(p).ravel()
        ATq = PT(q)
        x = (x - tau * (ATq - div_p)) / (1 + tau * 1e-6)  # light L2 regularization

        # Over-relaxation
        x_tilde = x + theta * (x - x_old)

        # Save intermediate result if needed
        if isSavingEachIteration and iteration in save_set:
            I_reconMatrix.append(x.reshape(Z, X) * scale)
            saved_indices.append(iteration)

    # Return results
    if isSavingEachIteration:
        return [tensor.cpu().numpy() for tensor in I_reconMatrix], saved_indices
    return (x.reshape(Z, X) * scale).cpu().numpy(), None
125
+
126
+
127
def CP_KL(
    SMatrix,
    y,
    alpha=1e-9,
    theta=1.0,
    numIterations=5000,
    isSavingEachIteration=True,
    L=None,
    withTumor=True,
    device=None,
    max_saves=5000,
):
    """
    Chambolle-Pock algorithm with a Kullback-Leibler (KL) data term.

    Works on both CPU and GPU.

    Args:
        SMatrix: System matrix (shape: T, Z, X, N)
        y: Measurement data (shape: T, N); flattened to T*N values
        alpha: Regularization parameter. NOTE(review): currently only shown
            in the progress-bar description; it is not passed to prox_F_star
            or prox_G.
        theta: Relaxation parameter (1.0 for standard Chambolle-Pock)
        numIterations: Number of iterations
        isSavingEachIteration: If True, returns selected intermediate reconstructions
        L: Operator-norm estimate (estimated with power_method if None)
        withTumor: Boolean for description only
        device: Torch device or device string (auto-selected if None)
        max_saves: Maximum number of intermediate saves (default: 5000)

    Returns:
        (list of (Z, X) arrays, list of iteration indices) if saving, where
        index ``i`` holds the image after iteration ``i`` (as in CP_TV);
        otherwise ((Z, X) final image, None).
    """
    # Auto-select device if not provided
    if device is None:
        device = torch.device(f"cuda:{config.select_best_gpu()}" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    # Convert data to tensors and move to device
    A = torch.tensor(SMatrix, dtype=torch.float32, device=device)
    y = torch.tensor(y, dtype=torch.float32, device=device)
    T, Z, X, N = SMatrix.shape
    A_flat = A.permute(0, 3, 1, 2).reshape(T * N, Z * X)
    y_flat = y.reshape(-1)

    # Define forward/backward operators
    P = lambda x: torch.matmul(A_flat, x.ravel())
    PT = lambda y: torch.matmul(A_flat.T, y)

    # Estimate the operator norm if needed
    if L is None:
        L = power_method(P, PT, y_flat, Z, X)

    sigma = 1.0 / L
    tau = 1.0 / L

    # Initialize variables
    x = torch.zeros(Z * X, device=device)
    q = torch.zeros_like(y_flat)
    x_tilde = x.clone()
    # A^T 1 does not change across iterations: compute it once.
    AT_ones = PT(torch.ones_like(y_flat))

    # Calculate save indices (always include the last iteration)
    if numIterations <= max_saves:
        save_indices = list(range(numIterations))
    else:
        step = numIterations // max_saves
        save_indices = list(range(0, numIterations, step))
        if save_indices[-1] != numIterations - 1:
            save_indices.append(numIterations - 1)
    save_set = set(save_indices)

    # Snapshots are taken after each update, so index 0 appears only once
    # (the original also stored the all-zero initial image under index 0).
    I_reconMatrix = []
    saved_indices = []

    # Description for progress bar: label with the selected device, not the current one
    tumor_str = "WITH TUMOR" if withTumor else "WITHOUT TUMOR"
    if device.type == "cuda":
        gpu_id = device.index if device.index is not None else torch.cuda.current_device()
        device_str = f"GPU no.{gpu_id}"
    else:
        device_str = "CPU"
    description = f"AOT-BioMaps -- Primal/Dual Reconstruction (KL) α:{alpha:.4f} L:{L:.4f} -- {tumor_str} -- {device_str}"

    # Main loop
    for iteration in trange(numIterations, desc=description):
        # Update q (proximal step for F*)
        q = prox_F_star(q + sigma * P(x_tilde) - sigma * y_flat, sigma, y_flat)

        # Update x (proximal step for G)
        x_old = x.clone()
        x = prox_G(x - tau * PT(q), tau, AT_ones)

        # Update x_tilde
        x_tilde = x + theta * (x - x_old)

        # Save intermediate result if needed
        if isSavingEachIteration and iteration in save_set:
            I_reconMatrix.append(x.reshape(Z, X).cpu().numpy())
            saved_indices.append(iteration)

    # Return results: the non-saving path returns the final iterate
    # (previously it returned the all-zero initial image).
    if isSavingEachIteration:
        return I_reconMatrix, saved_indices
    return x.reshape(Z, X).cpu().numpy(), None
@@ -0,0 +1,5 @@
1
+ from .DEPIERRO import *
2
+ from .MAPEM import *
3
+ from .MLEM import *
4
+ from .PDHG import *
5
+ from .LS import *
@@ -0,0 +1,90 @@
1
+ import numpy as np
2
+ import torch
3
+ from numba import njit
4
+
5
@njit
def _Omega_HUBER_PIECEWISE_CPU(theta_flat, index, values, delta):
    """
    Compute the Huber penalty, its gradient and its diagonal Hessian over weighted pixel pairs (NumPy/Numba).

    Parameters:
        theta_flat (np.ndarray): Flattened image, 1-D.
        index: Pair indices (j_idx, k_idx) — either a 2-row integer array or a
            2-tuple of 1-D integer arrays (e.g. COO row/col of a neighbor matrix).
        values (np.ndarray): Per-pair weights, same length as j_idx.
        delta (float): Threshold between the quadratic (|d| <= delta) and
            linear (|d| > delta) regimes.
    Returns:
        grad_U (np.ndarray): Per-pixel gradient, accumulated at j only.
        hess_U (np.ndarray): Per-pixel diagonal Hessian, accumulated at j only.
        U_value (float): 0.5 * sum of weighted pair penalties.

    NOTE(review): _Omega_HUBER_PIECEWISE_GPU also scatters -grad and +hess
    into k, so for the same inputs the two versions differ (by a factor of 2
    when each pair is listed in both directions). Confirm which is intended.
    """
    # Explicit row access works for both a 2-row array and a 2-tuple under Numba.
    j_idx = index[0]
    k_idx = index[1]
    diff = theta_flat[j_idx] - theta_flat[k_idx]
    abs_diff = np.abs(diff)

    # Huber penalty (potential function)
    psi_pair = np.where(abs_diff > delta,
                        delta * abs_diff - 0.5 * delta ** 2,
                        0.5 * diff ** 2)
    psi_pair = values * psi_pair

    # Huber gradient
    grad_pair = np.where(abs_diff > delta,
                         delta * np.sign(diff),
                         diff)
    grad_pair = values * grad_pair

    # Huber Hessian
    hess_pair = np.where(abs_diff > delta,
                         np.zeros_like(diff),
                         np.ones_like(diff))
    hess_pair = values * hess_pair

    grad_U = np.zeros_like(theta_flat)
    hess_U = np.zeros_like(theta_flat)

    # Unbuffered scatter-add (same result as np.add.at) as a plain loop;
    # np.add.at's support in Numba's nopython mode depends on the Numba
    # version, while an explicit loop always compiles.
    for p in range(j_idx.shape[0]):
        j = j_idx[p]
        grad_U[j] += grad_pair[p]
        hess_U[j] += hess_pair[p]

    # Total penalty energy
    U_value = 0.5 * np.sum(psi_pair)

    return grad_U, hess_U, U_value
51
+
52
def _Omega_HUBER_PIECEWISE_GPU(theta_flat, index, values, delta, device):
    """
    Compute the Huber penalty, its gradient and its diagonal Hessian over weighted pixel pairs (torch).

    Parameters:
        theta_flat (torch.Tensor): Flattened image, 1-D.
        index (torch.Tensor): Pair indices unpacked as (j_idx, k_idx), e.g. a
            2-row integer tensor (COO row/col of a neighbor matrix).
        values (torch.Tensor): Per-pair weights, same length as j_idx.
        delta (float): Threshold between the quadratic (|d| <= delta) and
            linear (|d| > delta) regimes.
        device: Device passed to ``torch.zeros_like`` for the outputs.
    Returns:
        grad_U (torch.Tensor): Per-pixel gradient (+grad at j, -grad at k),
            same dtype as theta_flat.
        hess_U (torch.Tensor): Per-pixel diagonal Hessian (added at j and k),
            same dtype as theta_flat.
        U_value (torch.Tensor): 0-dim tensor, 0.5 * sum of weighted pair penalties.

    NOTE(review): _Omega_HUBER_PIECEWISE_CPU accumulates only at j, so the
    two versions differ for the same inputs. Confirm which is intended.
    """
    j_idx, k_idx = index
    diff = theta_flat[j_idx] - theta_flat[k_idx]
    abs_diff = torch.abs(diff)
    is_linear = abs_diff > delta

    # Huber penalty
    psi_pair = values * torch.where(is_linear, delta * abs_diff - 0.5 * delta ** 2, 0.5 * diff ** 2)
    # First and second derivatives of the penalty w.r.t. diff
    grad_pair = values * torch.where(is_linear, delta * torch.sign(diff), diff)
    hess_pair = values * torch.where(is_linear, torch.zeros_like(diff), torch.ones_like(diff))

    grad_U = torch.zeros_like(theta_flat, device=device)
    hess_U = torch.zeros_like(theta_flat, device=device)

    # index_add_ requires source and destination to share a dtype; weights in
    # another precision (e.g. float64 values with a float32 image) would
    # otherwise promote the pair terms and make index_add_ raise.
    grad_pair = grad_pair.to(grad_U.dtype)
    hess_pair = hess_pair.to(hess_U.dtype)

    grad_U.index_add_(0, j_idx, grad_pair)
    grad_U.index_add_(0, k_idx, -grad_pair)

    hess_U.index_add_(0, j_idx, hess_pair)
    hess_U.index_add_(0, k_idx, hess_pair)

    U_value = 0.5 * psi_pair.sum()

    return grad_U, hess_U, U_value