AOT-biomaps 2.1.3-py3-none-any.whl → 2.9.233-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of AOT-biomaps might be problematic.

Files changed (50)
  1. AOT_biomaps/AOT_Acoustic/AcousticEnums.py +64 -0
  2. AOT_biomaps/AOT_Acoustic/AcousticTools.py +221 -0
  3. AOT_biomaps/AOT_Acoustic/FocusedWave.py +244 -0
  4. AOT_biomaps/AOT_Acoustic/IrregularWave.py +66 -0
  5. AOT_biomaps/AOT_Acoustic/PlaneWave.py +43 -0
  6. AOT_biomaps/AOT_Acoustic/StructuredWave.py +392 -0
  7. AOT_biomaps/AOT_Acoustic/__init__.py +15 -0
  8. AOT_biomaps/AOT_Acoustic/_mainAcoustic.py +978 -0
  9. AOT_biomaps/AOT_Experiment/Focus.py +55 -0
  10. AOT_biomaps/AOT_Experiment/Tomography.py +505 -0
  11. AOT_biomaps/AOT_Experiment/__init__.py +9 -0
  12. AOT_biomaps/AOT_Experiment/_mainExperiment.py +532 -0
  13. AOT_biomaps/AOT_Optic/Absorber.py +24 -0
  14. AOT_biomaps/AOT_Optic/Laser.py +70 -0
  15. AOT_biomaps/AOT_Optic/OpticEnums.py +17 -0
  16. AOT_biomaps/AOT_Optic/__init__.py +10 -0
  17. AOT_biomaps/AOT_Optic/_mainOptic.py +204 -0
  18. AOT_biomaps/AOT_Recon/AOT_Optimizers/DEPIERRO.py +191 -0
  19. AOT_biomaps/AOT_Recon/AOT_Optimizers/LS.py +106 -0
  20. AOT_biomaps/AOT_Recon/AOT_Optimizers/MAPEM.py +456 -0
  21. AOT_biomaps/AOT_Recon/AOT_Optimizers/MLEM.py +333 -0
  22. AOT_biomaps/AOT_Recon/AOT_Optimizers/PDHG.py +221 -0
  23. AOT_biomaps/AOT_Recon/AOT_Optimizers/__init__.py +5 -0
  24. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/Huber.py +90 -0
  25. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/Quadratic.py +86 -0
  26. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/RelativeDifferences.py +59 -0
  27. AOT_biomaps/AOT_Recon/AOT_PotentialFunctions/__init__.py +3 -0
  28. AOT_biomaps/AOT_Recon/AlgebraicRecon.py +1023 -0
  29. AOT_biomaps/AOT_Recon/AnalyticRecon.py +154 -0
  30. AOT_biomaps/AOT_Recon/BayesianRecon.py +230 -0
  31. AOT_biomaps/AOT_Recon/DeepLearningRecon.py +35 -0
  32. AOT_biomaps/AOT_Recon/PrimalDualRecon.py +210 -0
  33. AOT_biomaps/AOT_Recon/ReconEnums.py +375 -0
  34. AOT_biomaps/AOT_Recon/ReconTools.py +273 -0
  35. AOT_biomaps/AOT_Recon/__init__.py +11 -0
  36. AOT_biomaps/AOT_Recon/_mainRecon.py +288 -0
  37. AOT_biomaps/Config.py +95 -0
  38. AOT_biomaps/Settings.py +45 -13
  39. AOT_biomaps/__init__.py +271 -18
  40. aot_biomaps-2.9.233.dist-info/METADATA +22 -0
  41. aot_biomaps-2.9.233.dist-info/RECORD +43 -0
  42. {AOT_biomaps-2.1.3.dist-info → aot_biomaps-2.9.233.dist-info}/WHEEL +1 -1
  43. AOT_biomaps/AOT_Acoustic.py +0 -1881
  44. AOT_biomaps/AOT_Experiment.py +0 -541
  45. AOT_biomaps/AOT_Optic.py +0 -219
  46. AOT_biomaps/AOT_Reconstruction.py +0 -1416
  47. AOT_biomaps/config.py +0 -54
  48. AOT_biomaps-2.1.3.dist-info/METADATA +0 -20
  49. AOT_biomaps-2.1.3.dist-info/RECORD +0 -11
  50. {AOT_biomaps-2.1.3.dist-info → aot_biomaps-2.9.233.dist-info}/top_level.txt +0 -0
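
The listing shows the monolithic modules of 2.1.3 (AOT_Acoustic.py, AOT_Experiment.py, AOT_Optic.py, AOT_Reconstruction.py, config.py) being removed and replaced by subpackages. A minimal sketch of how the reorganized layout would be imported, assuming the module paths match the file names above; DEPIERRO, LS and config.get_process() do appear in the hunks later in this diff, but whether these names are also re-exported at the package level is not confirmed here:

# Illustrative only: import paths inferred from the new package layout listed above.
from AOT_biomaps.Config import config                              # replaces AOT_biomaps/config.py
from AOT_biomaps.AOT_Recon.AOT_Optimizers.DEPIERRO import DEPIERRO # defined in hunk 18 below
from AOT_biomaps.AOT_Recon.AOT_Optimizers.LS import LS             # defined in hunk 19 below

print(config.get_process())  # 'gpu' or 'cpu', as queried by the optimizer modules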
@@ -0,0 +1,204 @@
+ from .Laser import Laser
+ from .Absorber import Absorber
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+ class Phantom:
+     """
+     Class to apply absorbers to a laser field in the XZ plane.
+     """
+
+     def __init__(self, params):
+         """
+         Initializes the phantom with the given parameters.
+         :param params: Configuration parameters for the phantom.
+         """
+         try:
+             absorber_params = params.optic['absorbers']
+             self.absorbers = [Absorber(**a) for a in absorber_params] if absorber_params else []
+             self.laser = Laser(params)
+             self.phantom = self._apply_absorbers()
+             self.phantom = np.transpose(self.phantom)
+             self.laser.intensity = np.transpose(self.laser.intensity)
+             self.maskList = None  # List to store ROI masks
+         except KeyError as e:
+             raise ValueError(f"Missing parameter: {e}")
+         except Exception as e:
+             raise RuntimeError(f"Error initializing Phantom: {e}")
+
+     def __str__(self):
+         """
+         Returns a string representation of the Phantom object,
+         including its laser and absorber parameters.
+         :return: String representing the Phantom object.
+         """
+         try:
+             # Laser attributes
+             laser_attrs = {
+                 'shape': self.laser.shape.name.capitalize(),
+                 'center': self.laser.center,
+                 'w0': self.laser.w0,
+             }
+             laser_attr_lines = [f" {k}: {v}" for k, v in laser_attrs.items()]
+
+             # Absorber attributes
+             absorber_lines = []
+             for absorber in self.absorbers:
+                 absorber_lines.append(f" - name: \"{absorber.name}\"")
+                 absorber_lines.append(f"   type: \"{absorber.type}\"")
+                 absorber_lines.append(f"   center: {absorber.center}")
+                 absorber_lines.append(f"   radius: {absorber.radius}")
+                 absorber_lines.append(f"   amplitude: {absorber.amplitude}")
+
+             # Define borders and titles
+             border = "+" + "-" * 40 + "+"
+             title = f"| Type: {self.__class__.__name__} |"
+             laser_title = "| Laser Parameters |"
+             absorber_title = "| Absorbers |"
+
+             # Assemble the final result
+             result = f"{border}\n{title}\n{border}\n{laser_title}\n{border}\n"
+             result += "\n".join(laser_attr_lines)
+             result += f"\n{border}\n{absorber_title}\n{border}\n"
+             result += "\n".join(absorber_lines)
+             result += f"\n{border}"
+
+             return result
+         except Exception as e:
+             raise RuntimeError(f"Error generating string representation: {e}")
+
+     def find_ROI(self):
+         """
+         Computes binary masks for each ROI and stores them in self.maskList.
+         :return: True if pixels are detected in any ROI, False otherwise.
+         """
+         try:
+             X_mm, Z_mm = np.meshgrid(self.laser.x, self.laser.z, indexing='xy')
+             assert self.phantom.shape == X_mm.shape, (
+                 f"Shape mismatch: phantom={self.phantom.shape}, grid={X_mm.shape}"
+             )
+             self.maskList = []  # Reset the list
+             roi_found = False
+
+             for absorber in self.absorbers:
+                 center_x_mm = absorber.center[0] * 1000  # Convert to mm
+                 center_z_mm = absorber.center[1] * 1000  # Convert to mm
+                 radius_mm = absorber.radius * 1000  # Convert to mm
+
+                 # Create mask for this ROI
+                 mask_i = (X_mm - center_x_mm)**2 + (Z_mm - center_z_mm)**2 <= radius_mm**2
+                 self.maskList.append(mask_i)
+                 if mask_i.any():
+                     roi_found = True
+
+             return roi_found
+         except Exception as e:
+             raise RuntimeError(f"Error in find_ROI: {e}")
+
+     def _apply_absorbers(self):
+         """
+         Applies the absorbers to the laser field.
+         :return: Intensity matrix of the phantom with applied absorbers.
+         """
+         try:
+             X, Z = np.meshgrid(self.laser.x, self.laser.z, indexing='ij')
+             intensity = np.copy(self.laser.intensity)
+
+             for absorber in self.absorbers:
+                 r2 = (X - absorber.center[0] * 1000)**2 + (Z - absorber.center[1] * 1000)**2
+                 absorption = -absorber.amplitude * np.exp(-r2 / (absorber.radius * 1000)**2)
+                 intensity += absorption
+
+             return np.clip(intensity, 0, None)
+         except Exception as e:
+             raise RuntimeError(f"Error applying absorbers: {e}")
+
+     def show_phantom(self):
+         """
+         Displays the optical phantom with absorbers.
+         """
+         try:
+             plt.figure(figsize=(6, 6))
+             plt.imshow(
+                 self.phantom,
+                 extent=(self.laser.x[0], self.laser.x[-1] + 1, self.laser.z[-1], self.laser.z[0]),
+                 aspect='equal',
+                 cmap='hot'
+             )
+             plt.colorbar(label='Intensity')
+             plt.xlabel('X (mm)', fontsize=20)
+             plt.ylabel('Z (mm)', fontsize=20)
+             plt.tick_params(axis='both', which='major', labelsize=20)
+             plt.title('Optical Phantom with Absorbers')
+             plt.show()
+         except Exception as e:
+             raise RuntimeError(f"Error plotting phantom: {e}")
+
+     def show_ROI(self):
+         """
+         Displays the optical image with ROIs and average intensities.
+         Calls find_ROI() if self.maskList is empty.
+         """
+         try:
+             if not self.maskList:
+                 self.find_ROI()
+
+             fig, ax = plt.subplots(figsize=(6, 6))
+             im = ax.imshow(
+                 self.phantom,
+                 extent=(
+                     np.min(self.laser.x), np.max(self.laser.x),
+                     np.max(self.laser.z), np.min(self.laser.z)
+                 ),
+                 aspect='equal',
+                 cmap='hot'
+             )
+             divider = make_axes_locatable(ax)
+             cax = divider.append_axes("right", size="5%", pad=0.05)
+             plt.colorbar(im, cax=cax, label='Intensity')
+
+             # Draw ROIs
+             for i, absorber in enumerate(self.absorbers):
+                 center_x_mm = absorber.center[0] * 1000  # Convert to mm
+                 center_z_mm = absorber.center[1] * 1000  # Convert to mm
+                 radius_mm = absorber.radius * 1000  # Convert to mm
+
+                 circle = patches.Circle(
+                     (center_x_mm, center_z_mm),
+                     radius_mm,
+                     edgecolor='limegreen',
+                     facecolor='none',
+                     linewidth=2
+                 )
+                 ax.add_patch(circle)
+                 ax.text(
+                     center_x_mm,
+                     center_z_mm - 2,
+                     str(i + 1),
+                     color='limegreen',
+                     ha='center',
+                     va='center',
+                     fontsize=12,
+                     fontweight='bold'
+                 )
+
+             # Global mask (union of all ROIs)
+             ROI_mask = np.zeros_like(self.phantom, dtype=bool)
+             for mask in self.maskList:
+                 ROI_mask |= mask
+
+             roi_values = self.phantom[ROI_mask]
+             if roi_values.size == 0:
+                 print("❌ NO PIXELS IN ROIs! Check positions:")
+                 for i, absorber in enumerate(self.absorbers):
+                     print(f" Absorber {i}: center=({absorber.center[0]*1000:.3f}, {absorber.center[1]*1000:.3f}) mm")
+                     print(f" radius={absorber.radius*1000:.3f} mm")
+             else:
+                 print(f"✅ Average intensity in ROIs: {np.mean(roi_values):.4f}")
+
+             ax.set_xlabel('x (mm)')
+             ax.set_ylabel('z (mm)')
+             ax.set_title('Phantom with ROIs')
+             plt.tight_layout()
+             plt.show()
+         except Exception as e:
+             raise RuntimeError(f"Error in show_ROI: {e}")
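
The 204-line hunk above appears to be the new AOT_biomaps/AOT_Optic/_mainOptic.py (it matches the +204 entry in the file list). Its `_apply_absorbers` method subtracts a Gaussian-shaped dip from the laser intensity for each absorber, with centers and radii given in metres and converted to millimetres to match the laser grid, while `find_ROI` uses a hard disc of the same radius as the region of interest. A minimal self-contained sketch of that model on a toy grid; the grid and absorber values here are made up for illustration and do not come from the package:

import numpy as np

# Toy grid in mm, standing in for laser.x / laser.z
x = np.linspace(-10, 10, 201)
z = np.linspace(0, 20, 201)
X, Z = np.meshgrid(x, z, indexing='ij')

intensity = np.ones_like(X)  # stand-in for laser.intensity

# Hypothetical absorber: center and radius in metres (as in the params), amplitude in intensity units
center = (0.002, 0.010)   # 2 mm, 10 mm
radius = 0.001            # 1 mm
amplitude = 0.8

# Same Gaussian dip as Phantom._apply_absorbers
r2 = (X - center[0] * 1000) ** 2 + (Z - center[1] * 1000) ** 2
intensity += -amplitude * np.exp(-r2 / (radius * 1000) ** 2)
phantom = np.clip(intensity, 0, None)

# Disc mask of the same radius, as in find_ROI
roi_mask = (X - center[0] * 1000) ** 2 + (Z - center[1] * 1000) ** 2 <= (radius * 1000) ** 2
print(f"Mean intensity inside the ROI: {phantom[roi_mask].mean():.4f}")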
@@ -0,0 +1,191 @@
+ from AOT_biomaps.AOT_Recon.ReconEnums import PotentialType
+ from AOT_biomaps.AOT_Recon.ReconTools import _build_adjacency_sparse, calculate_memory_requirement, check_gpu_memory
+ from AOT_biomaps.Config import config
+
+ import warnings
+ import numpy as np
+ import torch
+ from tqdm import trange
+
+ if config.get_process() == 'gpu':
+     try:
+         from torch_scatter import scatter
+     except ImportError:
+         raise ImportError("torch_scatter and torch_sparse are required for GPU processing. Please install them using 'pip install torch-scatter torch-sparse' with the correct link (follow the instructions at https://github.com/LucasDuclos/AcoustoOpticTomography/edit/main/README.md).")
+
+ def DEPIERRO(
+     SMatrix,
+     y,
+     numIterations,
+     beta,
+     sigma,
+     isSavingEachIteration,
+     withTumor,
+     max_saves,
+     show_logs,
+     device=None):
+     """
+     This method implements the DEPIERRO algorithm using either CPU or single-GPU PyTorch acceleration.
+     Multi-GPU and Multi-CPU modes are not implemented for this algorithm.
+     """
+     try:
+         tumor_str = "WITH" if withTumor else "WITHOUT"
+         # Auto-select device and method
+         if device is None:
+             if torch.cuda.is_available() and check_gpu_memory(config.select_best_gpu(), calculate_memory_requirement(SMatrix, y), show_logs=show_logs):
+                 device = torch.device(f"cuda:{config.select_best_gpu()}")
+                 use_gpu = True
+             else:
+                 device = torch.device("cpu")
+                 use_gpu = False
+         else:
+             use_gpu = device.type == "cuda"
+         # Dispatch to the appropriate implementation
+         if use_gpu:
+             return _DEPIERRO_GPU(SMatrix, y, numIterations, beta, sigma, isSavingEachIteration, tumor_str, device, max_saves, show_logs)
+         else:
+             return _DEPIERRO_CPU(SMatrix, y, numIterations, beta, sigma, isSavingEachIteration, tumor_str, device, max_saves, show_logs)
+     except Exception as e:
+         print(f"Error in DEPIERRO: {type(e).__name__}: {e}")
+         return None, None
+
+ def _DEPIERRO_GPU(SMatrix, y, numIterations, beta, sigma, isSavingEachIteration, tumor_str, device, max_saves, show_logs=True):
+     # Convert the data to PyTorch tensors (float64)
+     A_matrix_torch = torch.tensor(SMatrix, dtype=torch.float64, device=device)
+     y_torch = torch.tensor(y, dtype=torch.float64, device=device)
+     # Dimensions
+     T, Z, X, N = SMatrix.shape
+     J = Z * X
+     # Reshape the matrices
+     A_flat = A_matrix_torch.permute(0, 3, 1, 2).reshape(T * N, J)
+     y_flat = y_torch.reshape(-1)
+     # Initialize theta
+     theta_0 = torch.ones((Z, X), dtype=torch.float64, device=device)
+     matrix_theta_torch = [theta_0.clone()]  # Clone to avoid shared references
+     I_reconMatrix = [theta_0.cpu().numpy()]
+     # Normalization (sensitivity) factor
+     normalization_factor = A_matrix_torch.sum(dim=(0, 3))
+     normalization_factor_flat = normalization_factor.reshape(-1)
+     # Build the sparse adjacency matrix
+     adj_index, adj_values = _build_adjacency_sparse(Z, X, device=device, dtype=torch.float64)
+     # Description for the progress bar
+     description = f"AOT-BioMaps -- Bayesian Reconstruction Tomography: DE PIERRO (Sparse QUADRATIC β:{beta:.4f}, σ:{sigma:.4f}) ---- {tumor_str} TUMOR ---- processing on single GPU no.{torch.cuda.current_device()}"
+     # Bookkeeping for the saved iterations
+     saved_indices = [0]
+
+     # Calculate save indices
+     if numIterations <= max_saves:
+         save_indices = list(range(numIterations))
+     else:
+         step = numIterations // max_saves
+         save_indices = list(range(0, numIterations, step))
+         if save_indices[-1] != numIterations - 1:
+             save_indices.append(numIterations - 1)
+
+     # Main MAP-EM loop
+     iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+     for it in iterator:
+         theta_p = matrix_theta_torch[-1]
+         theta_p_flat = theta_p.reshape(-1)
+         # Step 1: forward projection
+         q_flat = A_flat @ theta_p_flat
+         q_flat = q_flat + torch.finfo(torch.float64).tiny  # Avoid division by zero
+         # Step 2: error estimation
+         e_flat = y_flat / q_flat
+         # Step 3: back-projection of the error
+         c_flat = A_flat.T @ e_flat
+         # Step 4: multiplicative (EM) update
+         theta_EM_p_flat = theta_p_flat * c_flat
+         # Step 5: compute W_j and gamma_j
+         W_j = scatter(adj_values, adj_index[0], dim=0, dim_size=J, reduce='sum') * (1.0 / (sigma**2))
+         theta_k = theta_p_flat[adj_index[1]]
+         weighted_theta_k = theta_k * adj_values
+         gamma_j = theta_p_flat * W_j + scatter(weighted_theta_k, adj_index[0], dim=0, dim_size=J, reduce='sum')
+         # Step 6: De Pierro update (closed-form solution of a quadratic)
+         A_coeff = 2 * beta * W_j
+         B = -beta * gamma_j + normalization_factor_flat
+         C = -theta_EM_p_flat
+         discriminant = B**2 - 4 * A_coeff * C
+         discriminant = torch.clamp(discriminant, min=0)
+         theta_p_plus_1_flat = (-B + torch.sqrt(discriminant)) / (2 * A_coeff + torch.finfo(torch.float64).tiny)
+         theta_p_plus_1_flat = torch.clamp(theta_p_plus_1_flat, min=0)
+         # Step 7: update theta
+         theta_next = theta_p_plus_1_flat.reshape(Z, X)
+         matrix_theta_torch.append(theta_next)  # Append the new iterate
+         # Conditional saving
+         if isSavingEachIteration and it in save_indices:
+             I_reconMatrix.append(theta_next.cpu().numpy())
+             saved_indices.append(it)
+         # Partial memory release (optional, adjust as needed)
+         del theta_p_flat, q_flat, e_flat, c_flat, theta_EM_p_flat, theta_p_plus_1_flat
+         torch.cuda.empty_cache()
+
+     # Final release of the GPU tensors
+     del A_matrix_torch, y_torch, A_flat, y_flat, normalization_factor, normalization_factor_flat
+     torch.cuda.empty_cache()
+     # Return the result
+     if isSavingEachIteration:
+         return I_reconMatrix, saved_indices
+     else:
+         return matrix_theta_torch[-1].cpu().numpy(), None
+
+ def _DEPIERRO_CPU(SMatrix, y, numIterations, beta, sigma, isSavingEachIteration, tumor_str, device, max_saves, show_logs=True):
+     try:
+         if beta is None or sigma is None:
+             raise ValueError("Depierro95 optimizer requires beta and sigma parameters.")
+
+         A_matrix = np.array(SMatrix, dtype=np.float32)
+         y_array = np.array(y, dtype=np.float32)
+         T, Z, X, N = SMatrix.shape
+         J = Z * X
+         A_flat = A_matrix.transpose(0, 3, 1, 2).reshape(T * N, Z * X)
+         y_flat = y_array.reshape(-1)
+         theta_0 = np.ones((Z, X), dtype=np.float32)
+         matrix_theta = [theta_0]
+         I_reconMatrix = [theta_0.copy()]
+         saved_indices = [0]
+         normalization_factor = A_matrix.sum(axis=(0, 3))
+         normalization_factor_flat = normalization_factor.reshape(-1)
+         adj_index, adj_values = _build_adjacency_sparse(Z, X)
+
+         # Calculate save indices
+         if numIterations <= max_saves:
+             save_indices = list(range(numIterations))
+         else:
+             step = numIterations // max_saves
+             save_indices = list(range(0, numIterations, step))
+             if save_indices[-1] != numIterations - 1:
+                 save_indices.append(numIterations - 1)
+
+         description = f"AOT-BioMaps -- Bayesian Reconstruction Tomography: DE PIERRO (Sparse QUADRATIC β:{beta:.4f}, σ:{sigma:.4f}) ---- {tumor_str} TUMOR ---- processing on single CPU ----"
+
+         iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+         for it in iterator:
+             theta_p = matrix_theta[-1]
+             theta_p_flat = theta_p.reshape(-1)
+             q_flat = np.dot(A_flat, theta_p_flat)
+             e_flat = y_flat / (q_flat + np.finfo(np.float32).tiny)
+             c_flat = np.dot(A_flat.T, e_flat)
+             theta_EM_p_flat = theta_p_flat * c_flat
+             alpha_j = normalization_factor_flat
+             W_j = np.bincount(adj_index[0], weights=adj_values, minlength=J) * (1.0 / sigma**2)
+             theta_k = theta_p_flat[adj_index[1]]
+             weighted_theta_k = theta_k * adj_values
+             gamma_j = theta_p_flat * W_j + np.bincount(adj_index[0], weights=weighted_theta_k, minlength=J)
+             A = 2 * beta * W_j
+             B = -beta * gamma_j + alpha_j
+             C = -theta_EM_p_flat
+             theta_p_plus_1_flat = (-B + np.sqrt(B**2 - 4 * A * C)) / (2 * A + np.finfo(np.float32).tiny)
+             theta_p_plus_1_flat = np.clip(theta_p_plus_1_flat, a_min=0, a_max=None)
+             theta_next = theta_p_plus_1_flat.reshape(Z, X)
+             matrix_theta[-1] = theta_next
+             if isSavingEachIteration and it in save_indices:
+                 I_reconMatrix.append(theta_next.copy())
+                 saved_indices.append(it)
+
+         if isSavingEachIteration:
+             return I_reconMatrix, saved_indices
+         else:
+             return matrix_theta[-1], None
+     except Exception as e:
+         print(f"An error occurred in _DEPIERRO_CPU: {e}")
+         return None, None
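
The 191-line hunk above is the new DEPIERRO optimizer (matching the +191 entry for AOT_Recon/AOT_Optimizers/DEPIERRO.py). In Step 6, each pixel j of the new estimate is the non-negative root of the quadratic A θ² + B θ + C = 0 with A = 2 β W_j, B = α_j − β γ_j and C = −θ_j c_j, where α_j is the sensitivity, W_j and γ_j come from the neighbourhood weights, and θ_j c_j is the un-normalized EM numerator; because C ≤ 0, the "+" root is guaranteed non-negative. A tiny per-pixel sketch of that step with made-up scalar values (in the package these quantities come from the EM step and the adjacency matrix):

import numpy as np

# Illustrative only: single-pixel De Pierro update with a quadratic prior.
beta, sigma = 0.01, 1.0
W_j = 8.0 / sigma**2     # sum of neighbour weights for pixel j
gamma_j = 12.5           # theta_j * W_j + sum_k w_jk * theta_k
alpha_j = 3.2            # sensitivity: system matrix summed over t and n
theta_em = 1.7           # theta_j times the back-projected ratio (EM numerator)

A = 2 * beta * W_j
B = -beta * gamma_j + alpha_j
C = -theta_em
disc = max(B**2 - 4 * A * C, 0.0)
theta_new = max((-B + np.sqrt(disc)) / (2 * A + np.finfo(float).tiny), 0.0)

print(theta_new, A * theta_new**2 + B * theta_new + C)  # second value is ~0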
@@ -0,0 +1,106 @@
+ from AOT_biomaps.Config import config
+ import torch
+ import numpy as np
+ from tqdm import trange
+ from AOT_biomaps.AOT_Recon.ReconTools import calculate_memory_requirement, check_gpu_memory
+
+ def LS(
+     SMatrix,
+     y,
+     numIterations=5000,
+     alpha=1e-3,
+     isSavingEachIteration=True,
+     withTumor=True,
+     device=None,
+     max_saves=5000,
+     show_logs=True
+ ):
+     """
+     Least Squares reconstruction using Projected Gradient Descent (PGD) with a non-negativity constraint.
+     Currently only the stable GPU version is implemented.
+     """
+     tumor_str = "WITH" if withTumor else "WITHOUT"
+     # GPU usage is required for now
+     if device is None:
+         if not (torch.cuda.is_available() and check_gpu_memory(config.select_best_gpu(), calculate_memory_requirement(SMatrix, y), show_logs=show_logs)):
+             raise RuntimeError("CUDA is required for this implementation.")
+         device = torch.device(f"cuda:{config.select_best_gpu()}")
+     else:
+         if device.type != "cuda":
+             raise RuntimeError("Only GPU implementation is available for now.")
+     return _LS_GPU_stable(SMatrix, y, numIterations, alpha, isSavingEachIteration, tumor_str, max_saves, show_logs=show_logs)
+
+ def _LS_GPU_stable(SMatrix, y, numIterations, alpha, isSavingEachIteration, tumor_str, max_saves=5000, show_logs=True):
+     """
+     Stable GPU implementation of LS using projected gradient descent with a diagonal preconditioner.
+     """
+     device = torch.device(f"cuda:{config.select_best_gpu()}")
+     T, Z, X, N = SMatrix.shape
+     ZX = Z * X
+     TN = T * N
+     # 1. Conversion and normalization
+     A_flat = torch.from_numpy(SMatrix).to(device=device, dtype=torch.float32).permute(0, 3, 1, 2).reshape(TN, ZX)
+     y_flat = torch.from_numpy(y).to(device=device, dtype=torch.float32).reshape(TN)
+     norm_A = A_flat.max()
+     norm_y = y_flat.max()
+     A_flat.div_(norm_A + 1e-8)
+     y_flat.div_(norm_y + 1e-8)
+     # 2. Initialization
+     lambda_k = torch.zeros(ZX, device=device)
+     lambda_history = [] if isSavingEachIteration else None
+     saved_indices = []  # Indices of the saved iterations
+
+     # Calculate save indices
+     if numIterations <= max_saves:
+         save_indices = list(range(numIterations))
+     else:
+         step = numIterations // max_saves
+         save_indices = list(range(0, numIterations, step))
+         if save_indices[-1] != numIterations - 1:
+             save_indices.append(numIterations - 1)
+
+     # Diagonal preconditioner
+     diag_AAT = torch.sum(A_flat ** 2, dim=0)
+     M_inv = 1.0 / torch.clamp(diag_AAT, min=1e-6)
+     # Pre-allocate the tensors
+     r_k = torch.empty_like(y_flat)
+     AT_r = torch.empty(ZX, device=device)
+     description = f"AOT-BioMaps -- Stable LS Reconstruction ---- {tumor_str} TUMOR ---- GPU {torch.cuda.current_device()}"
+
+     iterator = trange(numIterations, desc=description) if show_logs else range(numIterations)
+     for it in iterator:
+         # Compute the residual (in place)
+         torch.matmul(A_flat, lambda_k, out=r_k)
+         r_k = y_flat - r_k
+         if isSavingEachIteration and it in save_indices:
+             lambda_history.append(lambda_k.clone().reshape(Z, X) * (norm_y / norm_A))
+             saved_indices.append(it)
+
+         # Preconditioned gradient (in place)
+         torch.matmul(A_flat.T, r_k, out=AT_r)
+         AT_r *= M_inv
+         # Fixed-step update and projection (in place)
+         lambda_k.add_(AT_r, alpha=alpha)
+         lambda_k.clamp_(min=0)
+
+     # 3. Denormalization
+     lambda_final = lambda_k.reshape(Z, X) * (norm_y / norm_A)
+     # Free memory
+     del A_flat, y_flat, r_k, AT_r
+     torch.cuda.empty_cache()
+     if isSavingEachIteration:
+         return [t.cpu().numpy() for t in lambda_history], saved_indices
+     else:
+         return lambda_final.cpu().numpy(), None
+
+ def _LS_GPU_opti(*args, **kwargs):
+     raise NotImplementedError("Only _LS_GPU_stable is implemented for now.")
+
+ def _LS_GPU_multi(*args, **kwargs):
+     raise NotImplementedError("Only _LS_GPU_stable is implemented for now.")
+
+ def _LS_CPU_opti(*args, **kwargs):
+     raise NotImplementedError("Only _LS_GPU_stable is implemented for now.")
+
+ def _LS_CPU_basic(*args, **kwargs):
+     raise NotImplementedError("Only _LS_GPU_stable is implemented for now.")
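
The 106-line hunk above is the new LS optimizer (matching the +106 entry for AOT_Recon/AOT_Optimizers/LS.py). _LS_GPU_stable is a preconditioned projected gradient descent: λ ← max(0, λ + α M⁻¹ Aᵀ(y − Aλ)), with M the diagonal of AᵀA and both A and y rescaled by their maxima before iterating. A minimal CPU/NumPy sketch of the same iteration on a random toy problem; the sizes and data are made up, and the package itself ships only the GPU path:

import numpy as np

rng = np.random.default_rng(0)
TN, ZX = 200, 50                      # toy sizes standing in for T*N and Z*X
A = np.abs(rng.standard_normal((TN, ZX))).astype(np.float32)
x_true = np.clip(rng.standard_normal(ZX), 0, None).astype(np.float32)
y = A @ x_true

# Normalization by the maxima, as in _LS_GPU_stable
A_n = A / (A.max() + 1e-8)
y_n = y / (y.max() + 1e-8)

# Diagonal preconditioner M = diag(A^T A)
M_inv = 1.0 / np.clip((A_n ** 2).sum(axis=0), 1e-6, None)

alpha = 1e-3
lam = np.zeros(ZX, dtype=np.float32)
for _ in range(5000):
    r = y_n - A_n @ lam                                          # residual
    lam = np.clip(lam + alpha * M_inv * (A_n.T @ r), 0, None)    # projected, preconditioned step

x_rec = lam * (y.max() / A.max())    # undo the normalization
print("relative error:", np.linalg.norm(x_rec - x_true) / np.linalg.norm(x_true))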