AMS-BP 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. AMS_BP/__init__.py +13 -0
  2. AMS_BP/cells/__init__.py +5 -0
  3. AMS_BP/cells/base_cell.py +55 -0
  4. AMS_BP/cells/rectangular_cell.py +82 -0
  5. AMS_BP/cells/rod_cell.py +98 -0
  6. AMS_BP/cells/spherical_cell.py +74 -0
  7. AMS_BP/configio/__init__.py +0 -0
  8. AMS_BP/configio/configmodels.py +93 -0
  9. AMS_BP/configio/convertconfig.py +910 -0
  10. AMS_BP/configio/experiments.py +121 -0
  11. AMS_BP/configio/saving.py +32 -0
  12. AMS_BP/metadata/__init__.py +0 -0
  13. AMS_BP/metadata/metadata.py +87 -0
  14. AMS_BP/motion/__init__.py +4 -0
  15. AMS_BP/motion/condensate_movement.py +356 -0
  16. AMS_BP/motion/movement/__init__.py +10 -0
  17. AMS_BP/motion/movement/boundary_conditions.py +75 -0
  18. AMS_BP/motion/movement/fbm_BP.py +244 -0
  19. AMS_BP/motion/track_gen.py +541 -0
  20. AMS_BP/optics/__init__.py +0 -0
  21. AMS_BP/optics/camera/__init__.py +4 -0
  22. AMS_BP/optics/camera/detectors.py +320 -0
  23. AMS_BP/optics/camera/quantum_eff.py +66 -0
  24. AMS_BP/optics/filters/__init__.py +17 -0
  25. AMS_BP/optics/filters/channels/__init__.py +0 -0
  26. AMS_BP/optics/filters/channels/channelschema.py +27 -0
  27. AMS_BP/optics/filters/filters.py +184 -0
  28. AMS_BP/optics/lasers/__init__.py +28 -0
  29. AMS_BP/optics/lasers/laser_profiles.py +691 -0
  30. AMS_BP/optics/psf/__init__.py +7 -0
  31. AMS_BP/optics/psf/psf_engine.py +215 -0
  32. AMS_BP/photophysics/__init__.py +0 -0
  33. AMS_BP/photophysics/photon_physics.py +181 -0
  34. AMS_BP/photophysics/state_kinetics.py +146 -0
  35. AMS_BP/probabilityfuncs/__init__.py +0 -0
  36. AMS_BP/probabilityfuncs/markov_chain.py +143 -0
  37. AMS_BP/probabilityfuncs/probability_functions.py +350 -0
  38. AMS_BP/run_cell_simulation.py +217 -0
  39. AMS_BP/sample/__init__.py +0 -0
  40. AMS_BP/sample/flurophores/__init__.py +16 -0
  41. AMS_BP/sample/flurophores/flurophore_schema.py +290 -0
  42. AMS_BP/sample/sim_sampleplane.py +334 -0
  43. AMS_BP/sim_config.toml +418 -0
  44. AMS_BP/sim_microscopy.py +453 -0
  45. AMS_BP/utils/__init__.py +0 -0
  46. AMS_BP/utils/constants.py +11 -0
  47. AMS_BP/utils/decorators.py +227 -0
  48. AMS_BP/utils/errors.py +37 -0
  49. AMS_BP/utils/maskMaker.py +12 -0
  50. AMS_BP/utils/util_functions.py +319 -0
  51. ams_bp-0.0.2.dist-info/METADATA +173 -0
  52. ams_bp-0.0.2.dist-info/RECORD +55 -0
  53. ams_bp-0.0.2.dist-info/WHEEL +4 -0
  54. ams_bp-0.0.2.dist-info/entry_points.txt +2 -0
  55. ams_bp-0.0.2.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,215 @@
1
+ from dataclasses import dataclass
2
+ from functools import cache, cached_property, lru_cache
3
+ from typing import Literal, Optional, Tuple
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+
8
+
9
@dataclass(frozen=True)
class PSFParameters:
    """Immutable optical settings used to build a Gaussian PSF model.

    Attributes:
        wavelength: Light wavelength in nanometers
        numerical_aperture: Numerical aperture of the optical system
        pixel_size: Size of pixels in micrometers
        z_step: Axial step size in micrometers
        refractive_index: Refractive index of the medium (default: 1.0 for air)

    NOTE: values are not validated on construction; in particular
    ``numerical_aperture`` is allowed to exceed ``refractive_index``.
    """

    wavelength: float
    numerical_aperture: float
    pixel_size: float
    z_step: float
    refractive_index: float = 1.0

    @cached_property
    def wavelength_um(self) -> float:
        """The wavelength converted from nanometers to micrometers (computed once)."""
        nm_per_um = 1000.0
        return self.wavelength / nm_per_um
47
+
48
+
49
class PSFEngine:
    """Engine for generating various microscope Point Spread Functions.

    This class implements calculations for both 2D and 3D Point Spread Functions
    using Gaussian approximations.

    ``psf_z`` and ``psf_z_xy0`` memoise their results per instance (at most
    ``_CACHE_SIZE`` z positions each). Cached data is therefore released together
    with the engine instead of being held, with the engine itself, by a
    class-wide cache.
    """

    # Maximum number of z positions memoised per instance for each cached method.
    _CACHE_SIZE = 128

    def __init__(self, params: "PSFParameters"):
        """Initialize PSF engine with given parameters."""
        self.params = params
        self._initialize_calculations()

    def _initialize_calculations(self) -> None:
        """Pre-compute the PSF widths, the grid size and the lateral coordinate grid."""
        self._sigma_xy = _calculate_sigma_xy(
            self.params.wavelength_um, self.params.numerical_aperture
        )
        self._sigma_z = _calculate_sigma_z(
            self.params.wavelength_um,
            self.params.numerical_aperture,
            self.params.refractive_index,
        )
        self._psf_size = calculate_psf_size(
            sigma_xy=self._sigma_xy,
            pixel_size=self.params.pixel_size,
            sigma_z=self._sigma_z,
        )
        self._grid_xy = _generate_grid(self._psf_size, self.params.pixel_size)

        # Dividing by 2.355 (= 2*sqrt(2*ln 2)) treats the widths above as FWHM
        # values and converts them to Gaussian standard deviations.
        self._norm_sigma_xy = self._sigma_xy / 2.355
        self._norm_sigma_z = self._sigma_z / 2.355

    def _memoize(self, store_name: str, key, compute):
        """Return ``compute(key)``, memoised in the per-instance dict ``store_name``.

        The dict is created lazily. It is kept in least-recently-used order (a
        plain dict preserves insertion order), and once it holds more than
        ``_CACHE_SIZE`` entries the least recently used entry is evicted.
        """
        store = self.__dict__.setdefault(store_name, {})
        if key in store:
            value = store.pop(key)
        else:
            value = compute(key)
        store[key] = value  # (re)insert as the most recently used entry
        if len(store) > self._CACHE_SIZE:
            del store[next(iter(store))]
        return value

    def psf_z(self, z_val: float) -> NDArray[np.float64]:
        """Generate z=z_val Gaussian approximation of PSF.

        Args:
            z_val: Z-position in micrometers

        Returns:
            2D array containing the PSF at given z position. A fresh copy is
            returned on every call, so callers may modify it in place without
            affecting later calls.
        """
        return self._memoize("_psf_z_cache", z_val, self._compute_psf_z).copy()

    def _compute_psf_z(self, z_val: float) -> NDArray[np.float64]:
        """Evaluate the unnormalised Gaussian PSF on the lateral grid at ``z_val``."""
        x, y = self._grid_xy

        # Vectorized calculation
        r_squared = (x / self._norm_sigma_xy) ** 2 + (y / self._norm_sigma_xy) ** 2
        z_term = (z_val / self._norm_sigma_z) ** 2
        return np.exp(-0.5 * (r_squared + z_term))

    def psf_z_xy0(self, z_val: float) -> float:
        """Generate z=z_val Gaussian approximation of PSF with x=y=0.

        Args:
            z_val: Z-position in micrometers

        Returns:
            PSF value at x=y=0 and given z position
        """
        return self._memoize("_psf_z_xy0_cache", z_val, self._compute_psf_z_xy0)

    def _compute_psf_z_xy0(self, z_val: float) -> float:
        """Evaluate the unnormalised Gaussian PSF on the optical axis at ``z_val``."""
        return np.exp(-0.5 * (z_val / self._norm_sigma_z) ** 2)

    def _3d_normalization_A(
        self, sigma_z: float, sigma_x: float, sigma_y: float
    ) -> float:
        """Return 1 / ((2*pi)^(3/2) * sigma_x * sigma_y * sigma_z)."""
        # Not memoised: the arithmetic is cheaper than a cache lookup, and a
        # method-level cache would keep every engine alive.
        return 1.0 / (((2.0 * np.pi) ** (3.0 / 2.0)) * sigma_x * sigma_y * sigma_z)

    def _2d_normalization_A(self, sigma_x: float, sigma_y: float) -> float:
        """Return 1 / (2*pi * sigma_x * sigma_y)."""
        return 1.0 / ((2.0 * np.pi) * sigma_x * sigma_y)

    @staticmethod
    def normalize_psf(
        psf: NDArray[np.float64], mode: Literal["sum", "max", "energy"] = "sum"
    ) -> NDArray[np.float64]:
        """Normalize PSF with different schemes.

        Args:
            psf: Input PSF array
            mode: Normalization mode
                - 'sum': Normalize so sum equals 1 (energy conservation)
                - 'max': Normalize so maximum equals 1
                - 'energy': Normalize so squared sum equals 1

        Returns:
            Normalized PSF array (a new array). An all-zero input is returned
            unchanged, whatever the mode.

        Raises:
            ValueError: If unknown normalization mode is specified
        """
        if not np.any(psf):  # Check if array is all zeros
            return psf

        normalizers = {
            "sum": np.sum,
            "max": np.max,
            "energy": lambda x: np.sqrt(np.sum(x**2)),
        }

        try:
            normalizer = normalizers[mode]
        except KeyError:
            raise ValueError(
                f"Unknown normalization mode: {mode}. Valid modes: {list(normalizers.keys())}"
            ) from None
        return psf / normalizer(psf)
156
+
157
+
158
+ @cache
159
+ def _calculate_sigma_xy(wavelength_um: float, numerical_aperture: float) -> float:
160
+ """Calculate lateral sigma value."""
161
+ return 0.61 * wavelength_um / numerical_aperture
162
+
163
+
164
+ @cache
165
+ def _calculate_sigma_z(
166
+ wavelength_um: float, numerical_aperture: float, refractive_index: float
167
+ ) -> float:
168
+ """Calculate axial sigma value."""
169
+ return 2.0 * wavelength_um * refractive_index / (numerical_aperture**2)
170
+
171
+
172
+ @cache
173
+ def _generate_grid(
174
+ size: Tuple[int, int], pixel_size: float
175
+ ) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
176
+ """Generate coordinate grids for PSF calculation.
177
+
178
+ Args:
179
+ size: Tuple of (height, width) for the grid
180
+
181
+ Returns:
182
+ Tuple of x and y coordinate arrays
183
+ """
184
+ y, x = np.ogrid[: size[0], : size[1]]
185
+ center_y, center_x = [(s - 1) / 2 for s in size]
186
+ y = (y - center_y) * pixel_size
187
+ x = (x - center_x) * pixel_size
188
+ return x, y
189
+
190
+
191
@cache
def calculate_psf_size(
    sigma_xy: float, pixel_size: float, sigma_z: float, z_size: Optional[int] = None
) -> Tuple[int, ...]:
    """Choose an odd-sized sample grid large enough to hold the PSF.

    Args:
        sigma_xy: Lateral PSF width (same length unit as ``pixel_size``).
            The lateral extent covered is ``2 * sigma_xy``.
        pixel_size: Lateral pixel size.
        sigma_z: Axial PSF width; only used when ``z_size`` is given.
        z_size: Optional axial divisor. When given, a leading z dimension of
            ``ceil(2 * sigma_z / z_size)`` samples (bumped to odd) is added.
            NOTE(review): the name suggests a number of z-planes, but the value
            is used as an axial step size (divisor) — confirm the intent.

    Returns:
        ``(z, y, x)`` when ``z_size`` is given, otherwise ``(y, x)``. Every
        dimension is odd so the grid has a central sample.
    """

    def _odd_ceil(extent: float, step: float) -> int:
        # Number of steps needed to span `extent`, rounded up to an odd count.
        count = int(np.ceil(extent / step))
        if count % 2 == 0:
            count += 1
        return count

    side_xy = _odd_ceil(2 * sigma_xy, pixel_size)
    if z_size is None:
        return (side_xy, side_xy)
    return (_odd_ceil(2 * sigma_z, z_size), side_xy, side_xy)
File without changes
@@ -0,0 +1,181 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, List, Optional, Tuple
3
+
4
+ import numpy as np
5
+
6
+ from ..optics.camera.detectors import photon_noise
7
+ from ..optics.camera.quantum_eff import QuantumEfficiency
8
+ from ..optics.filters.filters import FilterSpectrum
9
+ from ..optics.psf.psf_engine import PSFEngine
10
+ from ..sample.flurophores.flurophore_schema import (
11
+ SpectralData,
12
+ WavelengthDependentProperty,
13
+ )
14
+ from ..utils.constants import H_C_COM
15
+
16
+
17
@dataclass
class AbsorptionBase:
    """Incident-light inputs shared by absorption-type photophysics models.

    After construction, ``flux_density_precursor_lambda`` holds, for every
    incident wavelength, ``intensity * wavelength * excitation_spectrum(wavelength)``.
    """

    excitation_spectrum: SpectralData  # wl in nm, relative intensity
    intensity_incident: WavelengthDependentProperty  # wl in nm, intensity in W/um^2
    absorb_cross_section_spectrum: (
        WavelengthDependentProperty  # wl in nm, cross section in cm^2
    )

    def __post_init__(self):
        self.flux_density_precursor_lambda = self._calc_flux_density_precursor()

    def _calc_flux_density_precursor(self) -> WavelengthDependentProperty:
        """Weight each incident intensity by its wavelength and the excitation spectrum.

        Returns:
            A WavelengthDependentProperty on the incident wavelengths whose values
            are ``intensity * wavelength * excitation_spectrum.get_value(wavelength)``.
            NOTE(review): the intensity unit (W/um^2 per the field comment, W/cm^2
            in older notes) should be confirmed against the conversion factors
            applied downstream.
        """
        incident = self.intensity_incident
        sampled_wavelengths = []
        weighted_flux = []
        for idx, wl in enumerate(incident.wavelengths):
            power = incident.values[idx]
            relative_excitation = self.excitation_spectrum.get_value(wl)
            weighted_flux.append(power * wl * relative_excitation)
            sampled_wavelengths.append(wl)
        return WavelengthDependentProperty(
            wavelengths=sampled_wavelengths, values=weighted_flux
        )
41
+
42
+
43
@dataclass
class AbsorptionPhysics(AbsorptionBase):
    """Absorption model that turns incident light into an absorbed-photon rate."""

    fluorescent_lifetime_inverse: float

    def saturation_rate(self, rate: float, max_rate: float) -> float:
        """Clamp ``rate`` into the interval ``[0, max_rate]``."""
        return np.clip(rate, 0, max_rate)

    def absorbed_photon_rate(self) -> float:
        """Calculate the rate of absorbed photons (1/s).

        Sums ``cross_section * flux_precursor * H_C_COM * 1e-1`` over the sampled
        wavelengths (cross sections are matched to wavelengths by position) and
        caps the total at ``fluorescent_lifetime_inverse``.

        Raises:
            ValueError: if ``flux_density_precursor_lambda`` is None.
        """
        precursor = self.flux_density_precursor_lambda
        if precursor is None:
            raise ValueError("Flux density not calculated")

        cross_sections = self.absorb_cross_section_spectrum.values
        # 1e-1 combines all the unit-conversion factors.
        total_rate = sum(
            cross_sections[idx] * precursor.values[idx] * H_C_COM * 1e-1
            for idx in range(len(precursor.wavelengths))
        )
        return self.saturation_rate(total_rate, self.fluorescent_lifetime_inverse)
70
+
71
+
72
@dataclass
class PhotoStateSwitchPhysics(AbsorptionBase):
    """Absorption inputs plus a quantum yield for light-driven state switching."""

    # Switching events per photon absorbed.
    # NOTE: spelled `quantum_yeild`; renaming it would change the constructor keyword.
    quantum_yeild: float
75
+
76
+
77
@dataclass
class EmissionPhysics:
    """Emission side of the photophysics: emitted and filter-transmitted photon rates.

    On construction the emission spectrum is rescaled so its values sum to 1.
    Quantum-yield values are matched to emission wavelengths by position.
    """

    emission_spectrum: SpectralData  # wl in nm, normalized intensity
    quantum_yield: WavelengthDependentProperty
    transmission_filter: FilterSpectrum

    def __post_init__(self):
        # Rescale the emission spectrum so that its intensities sum to 1.
        raw_spectrum = self.emission_spectrum
        spectrum_total = sum(raw_spectrum.values)
        self.emission_spectrum = SpectralData(
            wavelengths=raw_spectrum.wavelengths,
            intensities=[value / spectrum_total for value in raw_spectrum.values],
        )

    def emission_photon_rate(
        self,
        total_absorbed_rate: float,  # 1/s
    ) -> WavelengthDependentProperty:
        """Split an absorbed-photon rate into per-wavelength emission rates (1/s).

        Parameters:
            total_absorbed_rate: float, absorbed photons per second

        Returns:
            Rates ``total_absorbed_rate * quantum_yield[i] * emission_spectrum[i]``
            on the emission-spectrum wavelengths.
        """
        spectrum = self.emission_spectrum
        indices = range(len(spectrum.wavelengths))
        emitted = [
            total_absorbed_rate * self.quantum_yield.values[i] * spectrum.values[i]
            for i in indices
        ]
        return WavelengthDependentProperty(
            wavelengths=[spectrum.wavelengths[i] for i in indices], values=emitted
        )

    def transmission_photon_rate(
        self, emission_photon_rate_lambda: WavelengthDependentProperty
    ) -> WavelengthDependentProperty:
        """Attenuate per-wavelength emission rates by the transmission filter (1/s).

        Parameters:
            emission_photon_rate_lambda: WavelengthDependentProperty of emission rates

        Returns:
            Rates multiplied by ``transmission_filter.find_transmission(wavelength)``.
        """
        source = emission_photon_rate_lambda
        indices = range(len(source.wavelengths))
        transmitted = [
            source.values[i]
            * self.transmission_filter.find_transmission(source.wavelengths[i])
            for i in indices
        ]
        return WavelengthDependentProperty(
            wavelengths=[source.wavelengths[i] for i in indices], values=transmitted
        )
139
+
140
+
141
@dataclass
class incident_photons:
    """Turn transmitted photon rates into noisy, QE-weighted PSF images.

    ``generator[i]`` holds the PSF engine built by ``psf`` for wavelength ``i``,
    or 0 when that wavelength's transmitted rate is not positive.
    """

    transmission_photon_rate: WavelengthDependentProperty
    quantumEff: QuantumEfficiency
    psf: Callable[[float | int, Optional[float | int]], PSFEngine]
    position: Tuple[float, float, float]

    def __post_init__(self):
        rates = self.transmission_photon_rate
        z_pos = self.position[2]
        self.generator = [
            self.psf(rates.wavelengths[i], z_pos) if rates.values[i] > 0 else 0
            for i in range(len(rates.wavelengths))
        ]

    def incident_photons_calc(self, dt: float) -> Tuple[float, List]:
        """Return (total photons over ``dt``, list of noisy PSF images).

        One image is produced per wavelength with a positive rate. The list is
        therefore not indexed like the wavelength list. Photon counts are summed
        before the quantum efficiency is applied; each image is scaled by the QE
        after photon noise is added.
        """
        rates = self.transmission_photon_rate
        z_pos = self.position[2]
        total_photons = 0
        noisy_images = []
        for idx in range(len(rates.wavelengths)):
            if not rates.values[idx] > 0:
                continue
            qe = self.quantumEff.get_qe(rates.wavelengths[idx])
            n_photons = rates.values[idx] * dt
            total_photons += n_photons

            engine = self.generator[idx]
            # Lateral shape normalised to unit sum, then scaled by the on-axis
            # z attenuation and the expected photon count.
            shape = engine.normalize_psf(engine.psf_z(z_val=z_pos), mode="sum")
            expected_image = shape * engine.psf_z_xy0(z_val=z_pos) * n_photons

            noisy_images.append(photon_noise(expected_image) * qe)

        return total_photons, noisy_images
@@ -0,0 +1,146 @@
1
+ from functools import cache
2
+ from typing import Callable, Optional, Sequence, Tuple
3
+
4
+ import numpy as np
5
+ from pydantic import BaseModel
6
+
7
+ from ..sample.flurophores.flurophore_schema import State, StateType
8
+ from ..sample.sim_sampleplane import FluorescentObject
9
+
10
+
11
class ErnoMsg(BaseModel):
    # Status returned by StateTransitionCalculator.MCMC.
    #   success:        False when the walk ended with a pending transition past the window end
    #   erno_time:      timing information for an unsuccessful walk (seconds)
    #   erno_end_state: the state the pending transition leads to
    success: bool
    erno_time: Optional[float] = None
    erno_end_state: Optional[State] = None
15
+
16
+
17
class StateTransitionCalculator:
    """Stochastic walk through a fluorophore's photophysical states over one window.

    Transition rates are evaluated at the laser intensities returned by
    ``laser_intensity_generator`` and the walk advances with ``ssa_step``
    (Gillespie direct method). The time spent in each fluorescent state during
    the window is accumulated in ``fluorescent_state_history``.
    """

    def __init__(
        self,
        flurophoreobj: FluorescentObject,
        time_duration: int | float,
        current_global_time: int,
        laser_intensity_generator: Callable,
    ) -> None:
        self.flurophoreobj = flurophoreobj
        self.time_duration = time_duration  # seconds
        self.current_global_time = current_global_time  # ms (oversample motion time)
        self.laser_intensity_generator = laser_intensity_generator
        self.fluorescent_state_history = {}  # {fluorescent.state.name : [delta time (seconds), laser_intensites], ...}
        # Per-instance memo for _find_transitions; released with the calculator.
        self._transitions_by_state: dict = {}

    def __call__(
        self,
    ) -> Tuple[dict, State, ErnoMsg]:
        """Run the walk and return (fluorescent time history, final state, status)."""
        state, erno = self.MCMC()
        return self.fluorescent_state_history, state, erno

    def _initialize_state_hist(self, time_pos: int, time_laser: float) -> dict:
        """Initialise the history of every fluorescent state; return the laser intensities.

        Each fluorescent state gets ``[0, laser_intensities]``. The intensities are
        evaluated at ``position_history[time_pos]`` and at time ``time_laser``.
        """
        # NOTE(review): MCMC passes time_laser=0 rather than the global time —
        # confirm the generator expects a window-relative time.
        laser_intensities = self.laser_intensity_generator(
            florPos=self.flurophoreobj.position_history[time_pos],
            time=time_laser,
        )
        for i in self.flurophoreobj.fluorophore.states.values():
            if i.state_type == StateType.FLUORESCENT:
                self.fluorescent_state_history[i.name] = [0, laser_intensities]
        return laser_intensities

    def _credit_fluorescent_time(self, state_name: str, dwell: float) -> None:
        """Add ``dwell`` seconds to ``state_name``'s history entry if it is fluorescent."""
        state = self.flurophoreobj.fluorophore.states[state_name]
        if state.state_type == StateType.FLUORESCENT:
            self.fluorescent_state_history[state.name][0] += dwell

    def MCMC(self) -> Tuple[State, ErnoMsg]:
        """Walk the state graph for ``time_duration`` seconds.

        The starting state is the ``from_state`` of the first transition stored in
        ``state_history[current_global_time][2]`` (assumed non-empty).

        Returns:
            ``(state at the end of the window, ErnoMsg)``. ``ErnoMsg.success`` is
            False when the next sampled transition would occur after the end of
            the window. In that case ``erno_time`` is how far past the window end
            it falls (seconds), and ``erno_end_state`` is the state it leads to.
            If the current state has no outgoing transitions, or all of their rates
            are zero, it persists until the end of the window (success=True).
        """
        states = self.flurophoreobj.fluorophore.states
        time = 0
        transitions = self.flurophoreobj.state_history[self.current_global_time][2]
        final_state_name = transitions[0].from_state
        laser_intensities = self._initialize_state_hist(self.current_global_time, time)

        while time < self.time_duration:
            stateTransitionMatrixR = [
                sum(
                    state_transitions.rate()(laser["wavelength"], laser["intensity"])
                    for laser in laser_intensities.values()
                )
                for state_transitions in transitions
            ]  # 1/s
            if not stateTransitionMatrixR or sum(stateTransitionMatrixR) == 0:
                # No way out of the current state: it occupies the rest of the window.
                self._credit_fluorescent_time(
                    final_state_name, self.time_duration - time
                )
                break

            new_time, state_indx = ssa_step(
                stateTransitionMatrixR
            )  # seconds, index on transitions
            state_name = transitions[state_indx].to_state

            remaining = self.time_duration - time
            if new_time > remaining:
                # The transition happens after the window ends: stay in the current
                # state for the rest of the window and report the overshoot.
                erno = ErnoMsg(
                    success=False,
                    erno_time=new_time - remaining,
                    erno_end_state=states[state_name],
                )
                self._credit_fluorescent_time(final_state_name, remaining)
                return states[final_state_name], erno

            self._credit_fluorescent_time(final_state_name, new_time)
            final_state_name = state_name
            transitions = self._find_transitions(state_name)
            time += new_time

        return states[final_state_name], ErnoMsg(success=True)

    def _find_transitions(self, statename: str) -> list:
        """Return the fluorophore's transitions whose ``from_state`` is ``statename``.

        Results are memoised per instance, so the same list object is returned
        for repeated queries; callers must not modify it.
        """
        try:
            return self._transitions_by_state[statename]
        except KeyError:
            found = [
                stateTrans
                for stateTrans in self.flurophoreobj.fluorophore.transitions.values()
                if stateTrans.from_state == statename
            ]
            self._transitions_by_state[statename] = found
            return found
115
+
116
+
117
+ def ssa_step(reaction_rates: Sequence[float | int]) -> tuple[float, int]:
118
+ """
119
+ Perform one step of the SSA simulation.
120
+
121
+ Parameters:
122
+ - reaction_rates: List of reaction rates [k1, k2, ...]
123
+
124
+ Returns:
125
+ - dt: Time step to advance
126
+ - next_event: Index of the next reaction (0-based)
127
+ """
128
+ # Calculate propensities
129
+ propensities = np.array(reaction_rates)
130
+ total_propensity = np.sum(propensities)
131
+
132
+ if total_propensity == 0:
133
+ raise ValueError("Total propensity is zero; no reactions can occur.")
134
+
135
+ # Draw two random numbers
136
+ r1, r2 = np.random.uniform(0, 1, size=2)
137
+
138
+ # Compute time step
139
+ dt = -np.log(r1) / total_propensity
140
+
141
+ # Determine the next reaction
142
+ cumulative_propensities = np.cumsum(propensities)
143
+ threshold = r2 * total_propensity
144
+ next_event = np.searchsorted(cumulative_propensities, threshold)
145
+
146
+ return dt, next_event
File without changes
@@ -0,0 +1,143 @@
1
+ from functools import cache
2
+
3
+ import numpy as np
4
+ from scipy.linalg import fractional_matrix_power
5
+
6
+
7
def MCMC_state_selection(
    initial_state_index: int,
    transition_matrix: np.ndarray,
    possible_states: np.ndarray,
    n: int,
) -> np.ndarray:
    """
    Markov Chain Monte Carlo (MCMC) state selection.

    This function simulates state transitions using a Markov Chain Monte Carlo method.
    It selects the next state based on the current state and a transition matrix over `n` iterations.
    The probability in the transition matrix is the probability of switching to a new state in the "time" step from n-1 -> n.

    Parameters:
    -----------
    initial_state_index : int
        The index of the initial state in the possible states.
    transition_matrix : np.ndarray
        A square, row-stochastic matrix: row ``i`` gives the probabilities of
        moving from state ``i`` to each state (each row must sum to 1).
    possible_states : np.ndarray
        An array of possible states for the system.
    n : int
        The number of iterations to perform.

    Returns:
    --------
    np.ndarray
        A float array of length ``n`` with the state selected at each iteration
        (the initial state itself is not included).
    """
    state_selection = np.zeros(n)
    state_indices = np.arange(len(possible_states))
    current_state_index = initial_state_index
    for i in range(n):
        # Draw the next state from the current state's row of probabilities.
        current_state_index = np.random.choice(
            state_indices, p=transition_matrix[current_state_index]
        )
        state_selection[i] = possible_states[current_state_index]
    return state_selection


def MCMC_state_selection_rate(
    initial_state_index: int,
    transition_matrix: np.ndarray,  # in rate, (1/s) s= seconds
    possible_states: np.ndarray,
    n: int,
    time_unit: int,  # amount of time (ms) in one n; ms = milliseconds
):
    """Run MCMC_state_selection on a matrix of transition rates.

    Each rate ``k`` (1/s) is converted to the probability ``1 - exp(-k * dt)``
    of the transition occurring within one step of ``time_unit`` milliseconds.

    Raises:
        ValueError: if the converted matrix is not row-stochastic (some row does
            not sum to 1 within floating-point tolerance).
    """
    # 1/s -> 1/ms; asarray(float) also accepts nested lists and produces a new
    # array after the multiplication, so the caller's matrix is never modified.
    rates_per_ms = np.asarray(transition_matrix, dtype=float) * (1.0 / 1000.0)
    # Vectorised form of rate_to_probability(rate, time_unit) for every entry.
    probability_matrix = 1 - np.exp(-rates_per_ms * time_unit)

    # Rows are used as the probability vector of the next state, so each row
    # must sum to 1 (same tolerance that np.random.choice applies).
    row_sums = probability_matrix.sum(axis=1)
    tolerance = np.sqrt(np.finfo(np.float64).eps)
    if not np.allclose(row_sums, 1.0, rtol=0.0, atol=tolerance):
        raise ValueError(
            f"Each row of the converted transition matrix must sum to 1; got row sums {row_sums}"
        )

    # apply "MCMC_state_selection"
    return MCMC_state_selection(
        initial_state_index=initial_state_index,
        transition_matrix=probability_matrix,
        possible_states=possible_states,
        n=n,
    )
83
+
84
+
85
# convert from rate (1/s) to probability (0-1)
@cache
def rate_to_probability(rate: float, dt: float) -> float:
    """Return the probability that an event with the given rate occurs within ``dt``.

    Uses the exponential waiting-time law ``p = 1 - exp(-rate * dt)``.

    Parameters:
    -----------
    rate : float
        The rate (1/s)
    dt : float
        The time step (s) for the probability calculation

    Returns:
    --------
    float
        The probability (0-1)
    """
    expected_events = rate * dt
    return 1 - np.exp(-expected_events)
103
+
104
+
105
# convert from probability (0-1) to rate (1/s)
@cache
def probability_to_rate(probability: float, dt: float) -> float:
    """Return the rate whose per-``dt`` event probability is ``probability``.

    Inverse of ``1 - exp(-rate * dt)``: ``rate = -ln(1 - probability) / dt``.

    Parameters:
    -----------
    probability : float
        The probability (0-1)
    dt : float
        The time step (s) for the probability calculation

    Returns:
    --------
    float
        The rate (1/s); infinite when ``probability`` is 1.
    """
    survival = 1 - probability
    return -np.log(survival) / dt
118
+
119
+
120
+ # fractional probability util
121
+ def change_prob_time(
122
+ probability: np.ndarray | float, dt: float, dt_prime: float
123
+ ) -> np.ndarray:
124
+ """Change the probability defined for dt to dt'
125
+
126
+ Parameters:
127
+ -----------
128
+ probability : np.ndarray | float
129
+ The probability (0-1)
130
+ dt : float
131
+ The time step (s) for the probability calculation
132
+ dt_prime : float
133
+ The new time step (s) for the probability calculation
134
+
135
+ Returns:
136
+ --------
137
+ np.ndarray | float
138
+ The probability (0-1)
139
+ """
140
+ if isinstance(probability, np.ndarray):
141
+ return fractional_matrix_power(probability, dt_prime / dt)
142
+ else:
143
+ return probability ** (dt_prime / dt)