AMS-BP 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- AMS_BP/__init__.py +13 -0
- AMS_BP/cells/__init__.py +5 -0
- AMS_BP/cells/base_cell.py +55 -0
- AMS_BP/cells/rectangular_cell.py +82 -0
- AMS_BP/cells/rod_cell.py +98 -0
- AMS_BP/cells/spherical_cell.py +74 -0
- AMS_BP/configio/__init__.py +0 -0
- AMS_BP/configio/configmodels.py +93 -0
- AMS_BP/configio/convertconfig.py +910 -0
- AMS_BP/configio/experiments.py +121 -0
- AMS_BP/configio/saving.py +32 -0
- AMS_BP/metadata/__init__.py +0 -0
- AMS_BP/metadata/metadata.py +87 -0
- AMS_BP/motion/__init__.py +4 -0
- AMS_BP/motion/condensate_movement.py +356 -0
- AMS_BP/motion/movement/__init__.py +10 -0
- AMS_BP/motion/movement/boundary_conditions.py +75 -0
- AMS_BP/motion/movement/fbm_BP.py +244 -0
- AMS_BP/motion/track_gen.py +541 -0
- AMS_BP/optics/__init__.py +0 -0
- AMS_BP/optics/camera/__init__.py +4 -0
- AMS_BP/optics/camera/detectors.py +320 -0
- AMS_BP/optics/camera/quantum_eff.py +66 -0
- AMS_BP/optics/filters/__init__.py +17 -0
- AMS_BP/optics/filters/channels/__init__.py +0 -0
- AMS_BP/optics/filters/channels/channelschema.py +27 -0
- AMS_BP/optics/filters/filters.py +184 -0
- AMS_BP/optics/lasers/__init__.py +28 -0
- AMS_BP/optics/lasers/laser_profiles.py +691 -0
- AMS_BP/optics/psf/__init__.py +7 -0
- AMS_BP/optics/psf/psf_engine.py +215 -0
- AMS_BP/photophysics/__init__.py +0 -0
- AMS_BP/photophysics/photon_physics.py +181 -0
- AMS_BP/photophysics/state_kinetics.py +146 -0
- AMS_BP/probabilityfuncs/__init__.py +0 -0
- AMS_BP/probabilityfuncs/markov_chain.py +143 -0
- AMS_BP/probabilityfuncs/probability_functions.py +350 -0
- AMS_BP/run_cell_simulation.py +217 -0
- AMS_BP/sample/__init__.py +0 -0
- AMS_BP/sample/flurophores/__init__.py +16 -0
- AMS_BP/sample/flurophores/flurophore_schema.py +290 -0
- AMS_BP/sample/sim_sampleplane.py +334 -0
- AMS_BP/sim_config.toml +418 -0
- AMS_BP/sim_microscopy.py +453 -0
- AMS_BP/utils/__init__.py +0 -0
- AMS_BP/utils/constants.py +11 -0
- AMS_BP/utils/decorators.py +227 -0
- AMS_BP/utils/errors.py +37 -0
- AMS_BP/utils/maskMaker.py +12 -0
- AMS_BP/utils/util_functions.py +319 -0
- ams_bp-0.0.2.dist-info/METADATA +173 -0
- ams_bp-0.0.2.dist-info/RECORD +55 -0
- ams_bp-0.0.2.dist-info/WHEEL +4 -0
- ams_bp-0.0.2.dist-info/entry_points.txt +2 -0
- ams_bp-0.0.2.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,215 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from functools import cache, cached_property, lru_cache
|
3
|
+
from typing import Literal, Optional, Tuple
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
from numpy.typing import NDArray
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass(frozen=True)
class PSFParameters:
    """Immutable bundle of optical settings used to build a Gaussian PSF.

    Attributes:
        wavelength: Light wavelength in nanometers.
        numerical_aperture: Numerical aperture of the optical system.
        pixel_size: Size of pixels in micrometers.
        z_step: Axial step size in micrometers.
        refractive_index: Refractive index of the medium (defaults to 1.0, i.e. air).

    The values are stored as given; no range validation is performed.
    """

    wavelength: float
    numerical_aperture: float
    pixel_size: float
    z_step: float
    refractive_index: float = 1.0

    @cached_property
    def wavelength_um(self) -> float:
        """The wavelength converted from nanometers to micrometers (computed once)."""
        nanometers_per_micrometer = 1000.0
        return self.wavelength / nanometers_per_micrometer
|
47
|
+
|
48
|
+
|
49
|
+
class PSFEngine:
    """Engine for generating various microscope Point Spread Functions.

    This class implements calculations for both 2D and 3D Point Spread Functions
    using Gaussian approximations. Lateral and axial positions are in
    micrometers (the same units as ``PSFParameters.pixel_size``).
    """

    # 2.355 ~= 2*sqrt(2*ln 2), the FWHM / standard-deviation ratio of a Gaussian.
    # Dividing the widths from _calculate_sigma_xy/_calculate_sigma_z by it
    # treats those widths as full widths at half maximum.
    _FWHM_TO_SIGMA = 2.355

    def __init__(self, params: "PSFParameters"):
        """Initialize PSF engine with given parameters.

        Args:
            params: Optical parameters. Derived widths, the PSF grid size and the
                lateral coordinate grid are computed once here.
        """
        self.params = params
        self._initialize_calculations()
        # Per-instance memoisation. A class-level lru_cache on a method keys on
        # ``self`` and keeps engines alive for as long as the shared cache holds
        # entries for them; these caches are released together with the engine.
        self._psf_z_cache = lru_cache(maxsize=128)(self._compute_psf_z)
        self._psf_z_xy0_cache = lru_cache(maxsize=128)(self._compute_psf_z_xy0)

    def _initialize_calculations(self) -> None:
        """Initialize commonly used calculations."""
        self._sigma_xy = _calculate_sigma_xy(
            self.params.wavelength_um, self.params.numerical_aperture
        )
        self._sigma_z = _calculate_sigma_z(
            self.params.wavelength_um,
            self.params.numerical_aperture,
            self.params.refractive_index,
        )
        self._psf_size = calculate_psf_size(
            sigma_xy=self._sigma_xy,
            pixel_size=self.params.pixel_size,
            sigma_z=self._sigma_z,
        )
        self._grid_xy = _generate_grid(self._psf_size, self.params.pixel_size)

        # Gaussian standard deviations used in the PSF exponent.
        self._norm_sigma_xy = self._sigma_xy / self._FWHM_TO_SIGMA
        self._norm_sigma_z = self._sigma_z / self._FWHM_TO_SIGMA

    def psf_z(self, z_val: float) -> NDArray[np.float64]:
        """Generate z=z_val Gaussian approximation of PSF.

        Evaluates ``exp(-0.5 * ((x/s_xy)^2 + (y/s_xy)^2 + (z/s_z)^2))`` on the
        lateral pixel grid. The result is not normalized; the Gaussian peaks at
        1 at the origin. Results are memoised per engine (up to 128 z values)
        and the same array object is returned for repeated calls, so treat the
        returned array as read-only.

        Args:
            z_val: Z-position in micrometers

        Returns:
            2D array containing the PSF at given z position
        """
        return self._psf_z_cache(z_val)

    def _compute_psf_z(self, z_val: float) -> NDArray[np.float64]:
        """Uncached computation behind ``psf_z``."""
        x, y = self._grid_xy
        # Vectorized calculation; x has shape (1, W) and y (H, 1), so they broadcast.
        r_squared = (x / self._norm_sigma_xy) ** 2 + (y / self._norm_sigma_xy) ** 2
        z_term = (z_val / self._norm_sigma_z) ** 2
        return np.exp(-0.5 * (r_squared + z_term))

    def psf_z_xy0(self, z_val: float) -> float:
        """Generate z=z_val Gaussian approximation of PSF with x=y=0.

        Memoised per engine (up to 128 z values).

        Args:
            z_val: Z-position in micrometers

        Returns:
            PSF value at x=y=0 and given z position
        """
        return self._psf_z_xy0_cache(z_val)

    def _compute_psf_z_xy0(self, z_val: float) -> float:
        """Uncached computation behind ``psf_z_xy0``."""
        return np.exp(-0.5 * (z_val / self._norm_sigma_z) ** 2)

    @staticmethod
    @cache
    def _3d_normalization_A(sigma_z: float, sigma_x: float, sigma_y: float) -> float:
        """Return the normalization constant of a 3D Gaussian with the given sigmas."""
        return 1.0 / (((2.0 * np.pi) ** (3.0 / 2.0)) * sigma_x * sigma_y * sigma_z)

    @staticmethod
    @cache
    def _2d_normalization_A(sigma_x: float, sigma_y: float) -> float:
        """Return the normalization constant of a 2D Gaussian with the given sigmas."""
        return 1.0 / ((2.0 * np.pi) * sigma_x * sigma_y)

    @staticmethod
    def normalize_psf(
        psf: NDArray[np.float64], mode: Literal["sum", "max", "energy"] = "sum"
    ) -> NDArray[np.float64]:
        """Normalize PSF with different schemes.

        Args:
            psf: Input PSF array
            mode: Normalization mode
                - 'sum': Normalize so sum equals 1 (energy conservation)
                - 'max': Normalize so maximum equals 1
                - 'energy': Normalize so squared sum equals 1

        Returns:
            Normalized PSF array (a new array). An all-zero input is returned
            unchanged (same object), since it cannot be normalized.

        Raises:
            ValueError: If unknown normalization mode is specified
        """
        normalizers = {
            "sum": np.sum,
            "max": np.max,
            "energy": lambda x: np.sqrt(np.sum(x**2)),
        }
        # Look up the mode first and keep the try narrow, so only an unknown
        # mode is reported as ValueError, even for all-zero input.
        try:
            normalizer = normalizers[mode]
        except KeyError:
            raise ValueError(
                f"Unknown normalization mode: {mode}. Valid modes: {list(normalizers.keys())}"
            ) from None

        if not np.any(psf):  # Check if array is all zeros
            return psf
        return psf / normalizer(psf)
|
156
|
+
|
157
|
+
|
158
|
+
@cache
def _calculate_sigma_xy(wavelength_um: float, numerical_aperture: float) -> float:
    """Return the lateral PSF width ``0.61 * wavelength / NA``.

    The result is in the same length unit as ``wavelength_um`` and is memoised.
    """
    airy_coefficient = 0.61
    scaled_wavelength = airy_coefficient * wavelength_um
    return scaled_wavelength / numerical_aperture
|
162
|
+
|
163
|
+
|
164
|
+
@cache
def _calculate_sigma_z(
    wavelength_um: float, numerical_aperture: float, refractive_index: float
) -> float:
    """Return the axial PSF width ``2 * wavelength * n / NA**2`` (memoised)."""
    na_squared = numerical_aperture**2
    return 2.0 * wavelength_um * refractive_index / na_squared
|
170
|
+
|
171
|
+
|
172
|
+
@cache
def _generate_grid(
    size: Tuple[int, int], pixel_size: float
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Generate coordinate grids for PSF calculation.

    Coordinates are centred on the middle of the grid and scaled by
    ``pixel_size``. The grids are open (``np.ogrid``): ``x`` has shape
    ``(1, width)`` and ``y`` has shape ``(height, 1)``, so they broadcast.

    Because the result is memoised with ``functools.cache``, every caller with
    the same arguments receives the same array objects. Both arrays are therefore
    returned read-only, so one caller cannot corrupt another's grid.

    Args:
        size: Tuple of (height, width) for the grid
        pixel_size: Physical size of one pixel

    Returns:
        Tuple of x and y coordinate arrays
    """
    height, width = size
    y, x = np.ogrid[:height, :width]
    center_y, center_x = (height - 1) / 2, (width - 1) / 2
    y = (y - center_y) * pixel_size
    x = (x - center_x) * pixel_size
    x.setflags(write=False)
    y.setflags(write=False)
    return x, y
|
189
|
+
|
190
|
+
|
191
|
+
@cache
def calculate_psf_size(
    sigma_xy: float, pixel_size: float, sigma_z: float, z_size: Optional[int] = None
) -> Tuple[int, ...]:
    """Calculate appropriate PSF size based on physical parameters.

    The lateral extent is ``2 * sigma_xy`` expressed in pixels, rounded up and
    then bumped to the next odd number so the grid has a centre pixel.

    Args:
        sigma_xy: Lateral width, in the same unit as ``pixel_size``.
        pixel_size: Pixel size.
        sigma_z: Axial width; only used when ``z_size`` is given.
        z_size: Optional. When given, ``2 * sigma_z / z_size`` is rounded up to
            an odd count of z-planes. Note that ``z_size`` therefore acts as a
            divisor (like an axial step in ``sigma_z``'s units), although it is
            described as a number of z-planes. NOTE(review): confirm which
            meaning callers rely on.

    Returns:
        Tuple of dimensions (z,y,x) or (y,x) for the PSF calculation
    """

    def _odd(count: int) -> int:
        # Even counts are increased by one; odd counts are kept.
        return count if count % 2 else count + 1

    lateral_pixels = _odd(int(np.ceil((2 * sigma_xy) / pixel_size)))
    if z_size is None:
        return (lateral_pixels, lateral_pixels)

    axial_planes = _odd(int(np.ceil(2 * sigma_z / z_size)))
    return (axial_planes, lateral_pixels, lateral_pixels)
|
File without changes
|
@@ -0,0 +1,181 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Callable, List, Optional, Tuple
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
from ..optics.camera.detectors import photon_noise
|
7
|
+
from ..optics.camera.quantum_eff import QuantumEfficiency
|
8
|
+
from ..optics.filters.filters import FilterSpectrum
|
9
|
+
from ..optics.psf.psf_engine import PSFEngine
|
10
|
+
from ..sample.flurophores.flurophore_schema import (
|
11
|
+
SpectralData,
|
12
|
+
WavelengthDependentProperty,
|
13
|
+
)
|
14
|
+
from ..utils.constants import H_C_COM
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
class AbsorptionBase:
    """Common inputs for absorption-driven photophysics.

    On construction, the per-wavelength product consumed by the absorption-rate
    calculations is precomputed and stored in ``flux_density_precursor_lambda``.
    """

    excitation_spectrum: SpectralData  # wl in nm, relative intensity
    intensity_incident: WavelengthDependentProperty  # wl in nm, intensity in W/um^2
    absorb_cross_section_spectrum: WavelengthDependentProperty  # wl in nm, cross section in cm^2

    def __post_init__(self):
        self.flux_density_precursor_lambda = self._calc_flux_density_precursor()

    def _calc_flux_density_precursor(self) -> WavelengthDependentProperty:
        """Weight each incident intensity by its wavelength and the excitation spectrum.

        For every wavelength in ``intensity_incident`` (in the same order), the
        stored value is
        ``intensity * wavelength * excitation_spectrum.get_value(wavelength)``.

        NOTE(review): the field comment states intensity in W/um^2; confirm that
        this unit matches the conversion factors used downstream.
        """
        incident = self.intensity_incident
        kept_wavelengths = []
        weighted_values = []
        for idx, wavelength in enumerate(incident.wavelengths):
            intensity = incident.values[idx]
            relative_excitation = self.excitation_spectrum.get_value(wavelength)
            weighted_values.append(intensity * wavelength * relative_excitation)
            kept_wavelengths.append(wavelength)
        return WavelengthDependentProperty(
            wavelengths=kept_wavelengths, values=weighted_values
        )
|
41
|
+
|
42
|
+
|
43
|
+
@dataclass
class AbsorptionPhysics(AbsorptionBase):
    """Absorption of excitation light, capped at the inverse fluorescence lifetime."""

    fluorescent_lifetime_inverse: float

    def saturation_rate(self, rate: float, max_rate: float) -> float:
        """Clamp ``rate`` into the interval ``[0, max_rate]``."""
        return np.clip(rate, 0, max_rate)

    def absorbed_photon_rate(self) -> float:
        """Return the absorbed-photon rate (1/s), clamped to ``[0, fluorescent_lifetime_inverse]``.

        Sums ``cross_section * flux_density_precursor * H_C_COM * 1e-1`` over the
        incident wavelengths; ``H_C_COM`` and the 1e-1 factor combine the unit
        conversions.

        NOTE(review): the cross-section is read by position (``values[idx]``),
        which assumes ``absorb_cross_section_spectrum`` lists its values in the
        same order as ``intensity_incident``'s wavelengths. Confirm this holds.

        Raises:
            ValueError: if ``flux_density_precursor_lambda`` is None.
        """
        precursor = self.flux_density_precursor_lambda
        if precursor is None:
            raise ValueError("Flux density not calculated")

        cross_sections = self.absorb_cross_section_spectrum.values
        total_rate = sum(
            cross_sections[idx] * precursor.values[idx] * H_C_COM * 1e-1
            for idx in range(len(precursor.wavelengths))
        )
        return self.saturation_rate(total_rate, self.fluorescent_lifetime_inverse)
|
70
|
+
|
71
|
+
|
72
|
+
@dataclass
class PhotoStateSwitchPhysics(AbsorptionBase):
    """Absorption inputs plus the efficiency of light-driven state switching.

    Inherits the excitation spectrum, incident intensity and absorption
    cross-section fields from ``AbsorptionBase``.
    """

    # Switching events per photon absorbed. The field name's spelling is kept
    # as-is because callers construct this dataclass by keyword.
    quantum_yeild: float
|
75
|
+
|
76
|
+
|
77
|
+
@dataclass
class EmissionPhysics:
    """Emission side of the photophysics: spectrum, quantum yield and emission filter.

    On construction the emission spectrum is rescaled so that its intensities
    sum to 1.
    """

    emission_spectrum: SpectralData  # wl in nm, normalized intensity
    quantum_yield: WavelengthDependentProperty
    transmission_filter: FilterSpectrum

    def __post_init__(self):
        raw_spectrum = self.emission_spectrum
        total_intensity = sum(raw_spectrum.values)
        self.emission_spectrum = SpectralData(
            wavelengths=raw_spectrum.wavelengths,
            intensities=[value / total_intensity for value in raw_spectrum.values],
        )

    def emission_photon_rate(
        self,
        total_absorbed_rate: float,  # 1/s
    ) -> WavelengthDependentProperty:
        """Calculate the rate of emitted photons (1/s) at each emission wavelength.

        Each value is ``total_absorbed_rate * quantum_yield * normalized_emission``.

        NOTE(review): ``quantum_yield.values`` is indexed by the emission
        spectrum's positions, which assumes both share the same wavelength grid.

        Parameters:
            total_absorbed_rate: float
        """
        spectrum = self.emission_spectrum
        yields = self.quantum_yield.values
        rates = [
            total_absorbed_rate * yields[idx] * spectrum.values[idx]
            for idx in range(len(spectrum.wavelengths))
        ]
        return WavelengthDependentProperty(
            wavelengths=list(spectrum.wavelengths), values=rates
        )

    def transmission_photon_rate(
        self, emission_photon_rate_lambda: WavelengthDependentProperty
    ) -> WavelengthDependentProperty:
        """Calculate the rate of transmitted photons (1/s) after the emission filter.

        Each emitted rate is multiplied by the filter's transmission at that
        wavelength.

        Parameters:
            emission_photon_rate_lambda: WavelengthDependentProperty
        """
        emitted = emission_photon_rate_lambda
        transmitted = [
            emitted.values[idx]
            * self.transmission_filter.find_transmission(emitted.wavelengths[idx])
            for idx in range(len(emitted.wavelengths))
        ]
        return WavelengthDependentProperty(
            wavelengths=list(emitted.wavelengths), values=transmitted
        )
|
139
|
+
|
140
|
+
|
141
|
+
@dataclass
class incident_photons:
    """Turn per-wavelength transmitted photon rates into noisy detector images.

    Fields:
        transmission_photon_rate: transmitted photon rate (1/s) per wavelength.
        quantumEff: detector quantum efficiency, queried with ``get_qe(wavelength)``.
        psf: factory called as ``psf(wavelength, z)`` that returns a ``PSFEngine``.
        position: emitter position; only the z component (index 2) is used here.
    """

    transmission_photon_rate: WavelengthDependentProperty
    quantumEff: QuantumEfficiency
    psf: Callable[[float | int, Optional[float | int]], PSFEngine]
    position: Tuple[float, float, float]

    def __post_init__(self):
        # One PSF engine per wavelength with a positive rate, and 0 as a
        # placeholder otherwise, so self.generator stays index-aligned with
        # the wavelengths.
        rates = self.transmission_photon_rate
        self.generator = [
            self.psf(wavelength, self.position[2]) if rates.values[idx] > 0 else 0
            for idx, wavelength in enumerate(rates.wavelengths)
        ]

    def incident_photons_calc(self, dt: float) -> Tuple[float, List]:
        """Simulate the photons collected during an exposure of length ``dt``.

        For every wavelength with a positive rate, the expected count
        ``rate * dt`` is spread over the sum-normalized PSF at the emitter's z
        and scaled by the PSF's on-axis value at that z. The result is passed
        through ``photon_noise`` and then multiplied by the quantum efficiency.

        Returns:
            A pair ``(total expected photons before QE and noise, images)``.
            ``images`` contains one entry per wavelength with a positive rate.
        """
        total_photons = 0
        noisy_images = []
        rates = self.transmission_photon_rate
        for idx, wavelength in enumerate(rates.wavelengths):
            rate = rates.values[idx]
            if not rate > 0:
                continue
            qe = self.quantumEff.get_qe(wavelength)
            expected = rate * dt
            total_photons += expected

            engine = self.generator[idx]
            z = self.position[2]
            image = (
                engine.normalize_psf(engine.psf_z(z_val=z), mode="sum")
                * engine.psf_z_xy0(z_val=z)
                * expected
            )
            # NOTE(review): QE scales the image after photon_noise has been drawn
            # instead of thinning the photons before the noise; confirm intended.
            noisy_images.append(photon_noise(image) * qe)

        return total_photons, noisy_images
|
@@ -0,0 +1,146 @@
|
|
1
|
+
from functools import cache
|
2
|
+
from typing import Callable, Optional, Sequence, Tuple
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
7
|
+
from ..sample.flurophores.flurophore_schema import State, StateType
|
8
|
+
from ..sample.sim_sampleplane import FluorescentObject
|
9
|
+
|
10
|
+
|
11
|
+
class ErnoMsg(BaseModel):
    """Outcome of a ``StateTransitionCalculator`` walk.

    ``success`` is False when the next sampled transition would happen after
    the end of the simulated window. In that case ``erno_time`` and
    ``erno_end_state`` describe that pending transition; both default to None.
    """

    success: bool
    erno_time: Optional[float] = None
    erno_end_state: Optional[State] = None
|
15
|
+
|
16
|
+
|
17
|
+
class StateTransitionCalculator:
    """Simulate a fluorophore's state changes over one time window.

    Uses the Gillespie stochastic simulation algorithm (``ssa_step``). Transition
    rates are evaluated at the laser intensities seen at the fluorophore's
    position. The calculator records how long the molecule spends in each
    fluorescent state.
    """

    def __init__(
        self,
        flurophoreobj: FluorescentObject,
        time_duration: int | float,
        current_global_time: int,
        laser_intensity_generator: Callable,
    ) -> None:
        self.flurophoreobj = flurophoreobj
        self.time_duration = time_duration  # seconds
        self.current_global_time = current_global_time  # ms (oversample motion time)
        self.laser_intensity_generator = laser_intensity_generator
        self.fluorescent_state_history = {}  # {fluorescent.state.name : [delta time (seconds), laser_intensites], ...}
        # Per-instance memo for _find_transitions. A class-level @cache on the
        # method would keep every calculator (and its fluorophore) alive forever.
        self._transition_cache: dict = {}

    def __call__(
        self,
    ) -> Tuple[dict, State, ErnoMsg]:
        """Run the walk and return ``(fluorescent_state_history, final_state, status)``."""
        state, erno = self.MCMC()
        return self.fluorescent_state_history, state, erno

    def _initialize_state_hist(self, time_pos: int, time_laser: float) -> dict:
        """Query the laser intensities at the fluorophore and reset the history.

        Every fluorescent state gets the entry ``[0, laser_intensities]``.
        Returns the laser intensities obtained from ``laser_intensity_generator``.
        """
        laser_intensities = self.laser_intensity_generator(
            florPos=self.flurophoreobj.position_history[time_pos],
            time=time_laser,
        )
        for state in self.flurophoreobj.fluorophore.states.values():
            if state.state_type == StateType.FLUORESCENT:
                self.fluorescent_state_history[state.name] = [0, laser_intensities]
        return laser_intensities

    def _add_fluorescent_time(self, state_name: str, dt: float) -> None:
        """Credit ``dt`` seconds to ``state_name`` if it is a fluorescent state."""
        state = self.flurophoreobj.fluorophore.states[state_name]
        if state.state_type == StateType.FLUORESCENT:
            self.fluorescent_state_history[state.name][0] += dt

    def MCMC(self) -> Tuple[State, ErnoMsg]:
        """Walk the fluorophore's states for ``time_duration`` seconds.

        Returns:
            The state occupied at the end of the window and an ``ErnoMsg``. The
            message has ``success=False`` when the next sampled transition lands
            after the window. In that case ``erno_time`` is how far (in seconds)
            it overshoots the end of the window, and ``erno_end_state`` is the
            state it leads to.
        """
        states = self.flurophoreobj.fluorophore.states
        elapsed = 0
        transitions = self.flurophoreobj.state_history[self.current_global_time][2]
        # NOTE(review): raises IndexError if the stored transition list is empty,
        # because the current state's name is only recoverable from it here.
        current_state_name = transitions[0].from_state
        laser_intensities = self._initialize_state_hist(self.current_global_time, elapsed)

        while elapsed < self.time_duration:
            rates = [
                sum(
                    transition.rate()(laser["wavelength"], laser["intensity"])
                    for laser in laser_intensities.values()
                )
                for transition in transitions
            ]  # 1/s
            if not rates or sum(rates) == 0:
                # Nothing can fire, so the molecule stays in its current state for
                # the rest of the window. Credit that time if the state is fluorescent.
                self._add_fluorescent_time(
                    current_state_name, self.time_duration - elapsed
                )
                break

            dwell, transition_index = ssa_step(rates)  # seconds, index on transitions
            next_state_name = transitions[transition_index].to_state

            # The dwell is relative to the current time, so the overshoot check
            # must use the absolute time at which the transition would happen.
            if elapsed + dwell > self.time_duration:
                overshoot = elapsed + dwell - self.time_duration
                self._add_fluorescent_time(
                    current_state_name, self.time_duration - elapsed
                )
                erno = ErnoMsg(
                    success=False,
                    erno_time=overshoot,
                    erno_end_state=states[next_state_name],
                )
                return states[current_state_name], erno

            self._add_fluorescent_time(current_state_name, dwell)
            current_state_name = next_state_name
            transitions = self._find_transitions(next_state_name)
            elapsed += dwell

        return states[current_state_name], ErnoMsg(success=True)

    def _find_transitions(self, statename: str) -> list:
        """Return the fluorophore's transitions leaving ``statename`` (memoised per instance)."""
        cached = self._transition_cache.get(statename)
        if cached is None:
            cached = [
                stateTrans
                for stateTrans in self.flurophoreobj.fluorophore.transitions.values()
                if stateTrans.from_state == statename
            ]
            self._transition_cache[statename] = cached
        return cached
|
115
|
+
|
116
|
+
|
117
|
+
def ssa_step(reaction_rates: Sequence[float | int]) -> tuple[float, int]:
|
118
|
+
"""
|
119
|
+
Perform one step of the SSA simulation.
|
120
|
+
|
121
|
+
Parameters:
|
122
|
+
- reaction_rates: List of reaction rates [k1, k2, ...]
|
123
|
+
|
124
|
+
Returns:
|
125
|
+
- dt: Time step to advance
|
126
|
+
- next_event: Index of the next reaction (0-based)
|
127
|
+
"""
|
128
|
+
# Calculate propensities
|
129
|
+
propensities = np.array(reaction_rates)
|
130
|
+
total_propensity = np.sum(propensities)
|
131
|
+
|
132
|
+
if total_propensity == 0:
|
133
|
+
raise ValueError("Total propensity is zero; no reactions can occur.")
|
134
|
+
|
135
|
+
# Draw two random numbers
|
136
|
+
r1, r2 = np.random.uniform(0, 1, size=2)
|
137
|
+
|
138
|
+
# Compute time step
|
139
|
+
dt = -np.log(r1) / total_propensity
|
140
|
+
|
141
|
+
# Determine the next reaction
|
142
|
+
cumulative_propensities = np.cumsum(propensities)
|
143
|
+
threshold = r2 * total_propensity
|
144
|
+
next_event = np.searchsorted(cumulative_propensities, threshold)
|
145
|
+
|
146
|
+
return dt, next_event
|
File without changes
|
@@ -0,0 +1,143 @@
|
|
1
|
+
from functools import cache
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from scipy.linalg import fractional_matrix_power
|
5
|
+
|
6
|
+
|
7
|
+
def MCMC_state_selection(
    initial_state_index: int,
    transition_matrix: np.ndarray,
    possible_states: np.ndarray,
    n: int,
) -> np.ndarray:
    """
    Markov Chain Monte Carlo (MCMC) state selection.

    Walks a discrete-time Markov chain for ``n`` steps. At every step, the row
    of ``transition_matrix`` belonging to the current state is used as the
    probability distribution over the next state. Rows must therefore sum to 1,
    as required by ``np.random.choice``.

    Parameters:
    -----------
    initial_state_index : int
        The index of the initial state in the possible states.
    transition_matrix : np.ndarray
        A square matrix of per-step transition probabilities between states.
    possible_states : np.ndarray
        An array of possible states for the system.
    n : int
        The number of iterations to perform.

    Returns:
    --------
    np.ndarray
        A float array of length ``n``. Entry ``k`` is the state reached after
        step ``k + 1``; the initial state itself is not included.
    """
    # Look the start state up once so an invalid index fails immediately,
    # even when n == 0.
    _ = possible_states[initial_state_index]

    candidate_indices = np.arange(len(possible_states))
    trajectory = np.zeros(n)
    index = initial_state_index
    for step in range(n):
        index = np.random.choice(candidate_indices, p=transition_matrix[index])
        trajectory[step] = possible_states[index]
    return trajectory
|
55
|
+
|
56
|
+
|
57
|
+
def MCMC_state_selection_rate(
    initial_state_index: int,
    transition_matrix: np.ndarray,  # in rate, (1/s) s= seconds
    possible_states: np.ndarray,
    n: int,
    time_unit: int,  # amount of time (ms) in one n; ms = milliseconds
) -> np.ndarray:
    """MCMC state selection driven by a matrix of transition *rates*.

    The off-diagonal entry ``[i, j]`` is the rate (1/s) of jumping from state
    ``i`` to state ``j``; diagonal entries are ignored. The rates are converted
    into the exact one-step transition matrix of the continuous-time chain,
    ``P = expm(Q * dt)``. Here ``Q`` is the generator (off-diagonal rates,
    diagonal equal to minus the row sum) and ``dt`` is ``time_unit``
    milliseconds. This accounts for competing and repeated jumps within one
    step. Every row of ``P`` sums to 1, as ``MCMC_state_selection`` requires.

    Parameters:
    -----------
    initial_state_index : int
        Index of the starting state.
    transition_matrix : np.ndarray
        Square matrix of transition rates (1/s). It is not modified.
    possible_states : np.ndarray
        The state values.
    n : int
        Number of steps.
    time_unit : int
        Duration of one step in milliseconds.

    Returns:
    --------
    np.ndarray
        The states visited, as returned by ``MCMC_state_selection``.

    Raises:
    -------
    ValueError
        If the rate matrix is not square or has negative off-diagonal rates.
    """
    from scipy.linalg import expm

    rates = np.array(transition_matrix, dtype=float)  # copy; caller's matrix untouched
    if rates.ndim != 2 or rates.shape[0] != rates.shape[1]:
        raise ValueError("transition_matrix must be a square matrix of rates")
    np.fill_diagonal(rates, 0.0)
    if np.any(rates < 0):
        raise ValueError("transition rates must be non-negative")

    generator = rates - np.diag(rates.sum(axis=1))
    step_seconds = time_unit / 1000.0  # time_unit is in milliseconds
    probabilities = expm(generator * step_seconds)
    # Remove round-off (tiny negatives, row sums off by a few ulp) so that
    # np.random.choice accepts every row as a distribution.
    probabilities = np.clip(probabilities, 0.0, None)
    probabilities /= probabilities.sum(axis=1, keepdims=True)

    return MCMC_state_selection(
        initial_state_index=initial_state_index,
        transition_matrix=probabilities,
        possible_states=possible_states,
        n=n,
    )
|
83
|
+
|
84
|
+
|
85
|
+
# convert from rate (1/s) to probability (0-1)
@cache
def rate_to_probability(rate: float, dt: float) -> float:
    """Convert from rate (1/s) to probability (0-1)

    Returns the probability that an exponential event with the given rate
    occurs within ``dt``, i.e. ``1 - exp(-rate * dt)``. It is evaluated as
    ``-expm1(-rate * dt)``, which is mathematically identical but avoids the
    cancellation that loses precision when ``rate * dt`` is small.

    Parameters:
    -----------
    rate : float
        The rate (1/s)
    dt : float
        The time step (s) for the probability calculation

    Returns:
    --------
    float
        The probability (0-1)

    Notes:
    ------
    Results are memoised with ``functools.cache``, so the arguments must be
    hashable (scalars, not arrays).
    """
    return -np.expm1(-rate * dt)
|
103
|
+
|
104
|
+
|
105
|
+
# convert from probability (0-1) to rate (1/s)
@cache
def probability_to_rate(probability: float, dt: float) -> float:
    """Convert from probability (0-1) to rate (1/s)

    This is the inverse of ``1 - exp(-rate * dt)``, i.e.
    ``-log(1 - probability) / dt``. It is evaluated with ``log1p`` to keep
    precision for small probabilities. A probability of 1 yields an infinite
    rate.

    Parameters:
    -----------
    probability : float
        The probability (0-1)
    dt : float
        The time step (s) for the probability calculation

    Returns:
    --------
    float
        The rate (1/s)
    """
    return -np.log1p(-probability) / dt
|
118
|
+
|
119
|
+
|
120
|
+
# fractional probability util
|
121
|
+
def change_prob_time(
|
122
|
+
probability: np.ndarray | float, dt: float, dt_prime: float
|
123
|
+
) -> np.ndarray:
|
124
|
+
"""Change the probability defined for dt to dt'
|
125
|
+
|
126
|
+
Parameters:
|
127
|
+
-----------
|
128
|
+
probability : np.ndarray | float
|
129
|
+
The probability (0-1)
|
130
|
+
dt : float
|
131
|
+
The time step (s) for the probability calculation
|
132
|
+
dt_prime : float
|
133
|
+
The new time step (s) for the probability calculation
|
134
|
+
|
135
|
+
Returns:
|
136
|
+
--------
|
137
|
+
np.ndarray | float
|
138
|
+
The probability (0-1)
|
139
|
+
"""
|
140
|
+
if isinstance(probability, np.ndarray):
|
141
|
+
return fractional_matrix_power(probability, dt_prime / dt)
|
142
|
+
else:
|
143
|
+
return probability ** (dt_prime / dt)
|