FlowCyPy 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- FlowCyPy/__init__.py +15 -0
- FlowCyPy/_version.py +16 -0
- FlowCyPy/classifier.py +196 -0
- FlowCyPy/coupling_mechanism/__init__.py +4 -0
- FlowCyPy/coupling_mechanism/empirical.py +47 -0
- FlowCyPy/coupling_mechanism/mie.py +205 -0
- FlowCyPy/coupling_mechanism/rayleigh.py +115 -0
- FlowCyPy/coupling_mechanism/uniform.py +39 -0
- FlowCyPy/cytometer.py +198 -0
- FlowCyPy/detector.py +616 -0
- FlowCyPy/directories.py +36 -0
- FlowCyPy/distribution/__init__.py +16 -0
- FlowCyPy/distribution/base_class.py +59 -0
- FlowCyPy/distribution/delta.py +86 -0
- FlowCyPy/distribution/lognormal.py +94 -0
- FlowCyPy/distribution/normal.py +95 -0
- FlowCyPy/distribution/particle_size_distribution.py +110 -0
- FlowCyPy/distribution/uniform.py +96 -0
- FlowCyPy/distribution/weibull.py +80 -0
- FlowCyPy/event_correlator.py +244 -0
- FlowCyPy/flow_cell.py +122 -0
- FlowCyPy/helper.py +85 -0
- FlowCyPy/logger.py +322 -0
- FlowCyPy/noises.py +29 -0
- FlowCyPy/particle_count.py +102 -0
- FlowCyPy/peak_locator/__init__.py +4 -0
- FlowCyPy/peak_locator/base_class.py +163 -0
- FlowCyPy/peak_locator/basic.py +108 -0
- FlowCyPy/peak_locator/derivative.py +143 -0
- FlowCyPy/peak_locator/moving_average.py +114 -0
- FlowCyPy/physical_constant.py +19 -0
- FlowCyPy/plottings.py +270 -0
- FlowCyPy/population.py +239 -0
- FlowCyPy/populations_instances.py +49 -0
- FlowCyPy/report.py +236 -0
- FlowCyPy/scatterer.py +373 -0
- FlowCyPy/source.py +249 -0
- FlowCyPy/units.py +26 -0
- FlowCyPy/utils.py +191 -0
- FlowCyPy-0.5.0.dist-info/LICENSE +21 -0
- FlowCyPy-0.5.0.dist-info/METADATA +252 -0
- FlowCyPy-0.5.0.dist-info/RECORD +44 -0
- FlowCyPy-0.5.0.dist-info/WHEEL +5 -0
- FlowCyPy-0.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import numpy as np
|
|
5
|
+
from scipy.signal import find_peaks
|
|
6
|
+
from FlowCyPy.peak_locator.base_class import BasePeakLocator
|
|
7
|
+
from FlowCyPy.units import Quantity, volt
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class BasicPeakLocator(BasePeakLocator):
    """
    A basic peak detector class that identifies peaks in a signal using a threshold-based method.

    Peaks are searched directly in the ``Signal`` column with
    :func:`scipy.signal.find_peaks`; optional height (``threshold``) and spacing
    (``min_peak_distance``) criteria are applied.

    Parameters
    ----------
    threshold : Quantity, optional
        The minimum height required for a peak to be considered significant.
        Default is ``Quantity(0.0, volt)``. If ``None``, no height criterion is applied.
    rel_height : float, optional
        The relative height at which the peak width is measured. Default is ``0.5``.
    min_peak_distance : Quantity, optional
        The minimum time separation between detected peaks. Default is ``None``
        (no spacing criterion).
    """

    threshold: Quantity = Quantity(0.0, volt)
    rel_height: float = 0.5
    min_peak_distance: Quantity = None

    def init_data(self, dataframe: pd.DataFrame) -> None:
        """
        Initialize the data for peak detection.

        After delegating to the base class, ``threshold`` is converted to the units of
        the ``Signal`` column and ``min_peak_distance`` to the units of the ``Time``
        column, so their magnitudes can be compared with raw sample magnitudes.

        Parameters
        ----------
        dataframe : pd.DataFrame
            A DataFrame containing the signal data with columns 'Signal' and 'Time'.

        Raises
        ------
        ValueError
            If the DataFrame is missing required columns or is empty.
        """
        super().init_data(dataframe)

        if self.threshold is not None:
            self.threshold = self.threshold.to(self.data['Signal'].values.units)

        if self.min_peak_distance is not None:
            self.min_peak_distance = self.min_peak_distance.to(self.data['Time'].values.units)

    def _compute_algorithm_specific_features(self) -> tuple:
        """
        Locate peaks with the threshold algorithm and measure their widths.

        Returns
        -------
        tuple
            ``(peak_indices, widths_samples, width_heights, left_ips, right_ips)``;
            the last four entries come from ``self._compute_peak_widths`` evaluated
            on the ``Signal`` column.
        """
        peak_indices = self._compute_peak_positions()

        widths_samples, width_heights, left_ips, right_ips = self._compute_peak_widths(
            peak_indices,
            self.data['Signal'].values
        )

        return peak_indices, widths_samples, width_heights, left_ips, right_ips

    def _compute_peak_positions(self) -> np.ndarray:
        """
        Detect peaks in the raw signal.

        Returns
        -------
        numpy.ndarray
            Integer sample indices of the detected peaks.
        """
        # Work on plain magnitudes (same approach as the other peak locators);
        # threshold has already been converted to the signal units in init_data.
        signal = self.data['Signal'].values.quantity.magnitude

        height = None if self.threshold is None else self.threshold.magnitude

        distance = None
        if self.min_peak_distance is not None:
            # find_peaks requires distance >= 1 sample; tiny or zero spacings would
            # otherwise round to 0 and raise ValueError.
            distance = max(1, int(np.ceil(self.min_peak_distance / self.dt)))

        peak_indices, _ = find_peaks(signal, height=height, distance=distance)

        return peak_indices

    def _add_custom_to_ax(self, time_unit: str | Quantity, signal_unit: str | Quantity, ax: plt.Axes = None) -> None:
        """
        Add algorithm-specific elements to the plot.

        Parameters
        ----------
        time_unit : str or Quantity
            The unit for the time axis (e.g., 'microsecond').
        signal_unit : str or Quantity
            The unit for the signal axis (e.g., 'volt').
        ax : matplotlib.axes.Axes, optional
            The Axes object to add elements to. If ``None``, the current Axes
            (``plt.gca()``) is used.
        """
        if self.threshold is None:
            # No height criterion was configured, so there is no threshold line to draw.
            return

        if ax is None:
            ax = plt.gca()

        # Plot the signal threshold line
        ax.axhline(y=self.threshold.to(signal_unit).magnitude, color='black', linestyle='--', label='Threshold', lw=1)
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import numpy as np
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
from scipy.signal import find_peaks
|
|
5
|
+
from FlowCyPy.peak_locator.base_class import BasePeakLocator
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import pint_pandas
|
|
8
|
+
from FlowCyPy.units import Quantity, microsecond
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class DerivativePeakLocator(BasePeakLocator):
    """
    Detects peaks in a signal using a derivative-based algorithm.

    Candidate peaks are located where the time derivative of the signal exceeds
    ``derivative_threshold``; each candidate is then moved to the largest sample of
    the original signal within a few samples on either side of it.

    Parameters
    ----------
    derivative_threshold : Quantity, optional
        The minimum derivative value required to detect a peak.
        Default is ``Quantity(0.1, 'volt/microsecond')``. If ``None``, no height
        criterion is applied to the derivative.
    min_peak_distance : Quantity, optional
        The minimum time separation between detected peaks.
        Default is ``Quantity(0.1, microsecond)``.
    rel_height : float, optional
        The relative height at which the peak width is measured. Default is `0.5` (half-height).
    """

    derivative_threshold: Quantity = Quantity(0.1, 'volt/microsecond')
    min_peak_distance: Quantity = Quantity(0.1, microsecond)
    rel_height: float = 0.5

    def init_data(self, dataframe: pd.DataFrame) -> None:
        """
        Initialize the data for peak detection.

        After delegating to the base class, ``derivative_threshold`` is converted to
        ``Signal`` units per ``Time`` unit and ``min_peak_distance`` to ``Time`` units.

        Parameters
        ----------
        dataframe : pd.DataFrame
            A DataFrame containing the signal data with columns 'Signal' and 'Time'.

        Raises
        ------
        ValueError
            If the DataFrame is missing required columns or is empty.
        """
        super().init_data(dataframe)

        if self.derivative_threshold is not None:
            self.derivative_threshold = self.derivative_threshold.to(
                self.data['Signal'].values.units / self.data['Time'].values.units
            )

        if self.min_peak_distance is not None:
            self.min_peak_distance = self.min_peak_distance.to(self.data['Time'].values.units)

    def _compute_algorithm_specific_features(self) -> tuple:
        """
        Locate peaks with the derivative algorithm and measure their widths.

        Returns
        -------
        tuple
            ``(peak_indices, widths_samples, width_heights, left_ips, right_ips)``;
            the last four entries come from ``self._compute_peak_widths`` evaluated
            on the ``Signal`` column.
        """
        peak_indices = self._compute_peak_positions()

        widths_samples, width_heights, left_ips, right_ips = self._compute_peak_widths(
            peak_indices,
            self.data['Signal'].values
        )

        return peak_indices, widths_samples, width_heights, left_ips, right_ips

    def _compute_peak_positions(self) -> np.ndarray:
        """
        Compute peaks based on the derivative of the signal and refine their positions
        to align with the actual maxima in the original signal.

        Side effect: stores the derivative (with units) in ``self.data['Derivative']``.

        Returns
        -------
        numpy.ndarray
            Sorted, unique integer sample indices of the refined peaks.
        """
        signal_units = self.data['Signal'].values.units
        time_units = self.data['Time'].values.units

        signal = self.data['Signal'].values.quantity.magnitude
        time = self.data['Time'].values.quantity.magnitude

        # Compute the derivative of the signal against the actual time axis
        derivative = np.gradient(signal, time)

        # Keep the derivative (with units) in the DataFrame so it can be plotted later
        self.data['Derivative'] = pint_pandas.PintArray(derivative, dtype=signal_units / time_units)

        height = None if self.derivative_threshold is None else self.derivative_threshold.magnitude

        distance = None
        if self.min_peak_distance is not None:
            # find_peaks requires distance >= 1 sample
            distance = max(1, int(np.ceil(self.min_peak_distance / self.dt)))

        # Detect peaks in the derivative signal
        derivative_peak_indices, _ = find_peaks(
            derivative,
            height=height,
            prominence=0.1,  # NOTE: in derivative magnitude units, so its effect depends on the unit choice
            plateau_size=True,
            distance=distance
        )

        # Refine detected peaks to align with maxima in the original signal
        refinement_window = 5  # Number of samples on each side of a derivative peak searched for the max
        n_samples = len(signal)

        refined_peak_indices = []
        for idx in derivative_peak_indices:
            window_start = max(0, idx - refinement_window)
            # Slices are end-exclusive: +1 so idx + refinement_window (and the last sample) are included
            window_end = min(n_samples, idx + refinement_window + 1)

            true_max_idx = window_start + int(np.argmax(signal[window_start:window_end]))
            refined_peak_indices.append(true_max_idx)

        # Several derivative peaks may refine onto the same maximum; also force an integer
        # dtype so that an empty result remains usable as an index array.
        return np.unique(np.asarray(refined_peak_indices, dtype=int))

    def _add_custom_to_ax(self, time_unit: str | Quantity, signal_unit: str | Quantity, ax: plt.Axes) -> None:
        """
        Add algorithm-specific elements to the plot.

        Parameters
        ----------
        time_unit : str or Quantity
            The unit for the time axis (e.g., 'microsecond').
        signal_unit : str or Quantity
            The unit for the signal axis (e.g., 'volt').
        ax : matplotlib.axes.Axes
            The Axes object to add elements to.
        """
        # Plot the derivative
        ax.plot(
            self.data.Time,
            self.data.Derivative,
            linestyle='--',
            color='C2',
            label='Derivative'
        )

        if self.derivative_threshold is None:
            # No derivative height criterion configured: nothing more to draw.
            return

        # Plot the derivative threshold line
        ax.axhline(
            y=self.derivative_threshold.to(self.data['Derivative'].values.units).magnitude,
            color='black',
            linestyle='--',
            label='Derivative Threshold',
            lw=1
        )
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import numpy as np
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
from scipy.signal import find_peaks
|
|
5
|
+
from FlowCyPy.peak_locator.base_class import BasePeakLocator
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import pint_pandas
|
|
8
|
+
from FlowCyPy.units import Quantity, microsecond
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class MovingAverage(BasePeakLocator):
    """
    Detects peaks in a signal using a moving average algorithm.

    A centred moving average of the signal is subtracted from the signal and peaks
    are located in that difference; a peak is kept when the difference reaches
    ``threshold``.

    Parameters
    ----------
    threshold : Quantity, optional
        The minimum difference between the signal and its moving average required to
        detect a peak. Default is ``None`` (no height criterion).
    window_size : Quantity, optional
        The time span of the moving-average window. Default is ``Quantity(1, microsecond)``.
    min_peak_distance : Quantity, optional
        The minimum time separation between detected peaks. Default is ``None``.
    rel_height : float, optional
        The relative height at which the peak width is measured. Default is ``0.8``.
    """

    threshold: Quantity = None
    window_size: Quantity = Quantity(1, microsecond)
    min_peak_distance: Quantity = None
    rel_height: float = 0.8

    def init_data(self, dataframe: pd.DataFrame) -> None:
        """
        Initialize the data for peak detection.

        After delegating to the base class, ``threshold`` is converted to ``Signal``
        units and ``window_size`` / ``min_peak_distance`` to ``Time`` units.

        Parameters
        ----------
        dataframe : pd.DataFrame
            A DataFrame containing the signal data with columns 'Signal' and 'Time'.

        Raises
        ------
        ValueError
            If the DataFrame is missing required columns or is empty.
        """
        super().init_data(dataframe)

        if self.threshold is not None:
            self.threshold = self.threshold.to(self.data['Signal'].values.units)

        self.window_size = self.window_size.to(self.data['Time'].values.units)

        if self.min_peak_distance is not None:
            self.min_peak_distance = self.min_peak_distance.to(self.data['Time'].values.units)

    def _compute_algorithm_specific_features(self) -> tuple:
        """
        Compute peaks based on the moving average algorithm.

        Returns
        -------
        tuple
            ``(peak_indices, widths_samples, width_heights, left_ips, right_ips)``;
            widths are measured on the ``Difference`` column.
        """
        peak_indices = self._compute_peak_positions()

        widths_samples, width_heights, left_ips, right_ips = self._compute_peak_widths(peak_indices, self.data['Difference'].values)

        return peak_indices, widths_samples, width_heights, left_ips, right_ips

    def _compute_peak_positions(self) -> np.ndarray:
        """
        Build the moving-average difference signal and detect peaks in it.

        Side effect: adds the ``MovingAverage`` and ``Difference`` columns to ``self.data``.

        Returns
        -------
        numpy.ndarray
            Integer sample indices of the detected peaks.
        """
        # Calculate moving average; rolling() needs a window of at least one sample
        window_size_samples = max(1, int(np.ceil(self.window_size / self.dt)))
        moving_avg = self.data['Signal'].rolling(window=window_size_samples, center=True, min_periods=1).mean()

        # Reattach Pint units to the moving average and add it to the DataFrame
        self.data['MovingAverage'] = pint_pandas.PintArray(moving_avg, dtype=self.data['Signal'].values.units)

        # Compute the difference signal
        self.data['Difference'] = self.data['Signal'] - self.data['MovingAverage']

        height = None if self.threshold is None else self.threshold.magnitude

        distance = None
        if self.min_peak_distance is not None:
            # find_peaks requires distance >= 1 sample
            distance = max(1, int(np.ceil(self.min_peak_distance / self.dt)))

        # Detect peaks
        peak_indices, _ = find_peaks(
            self.data['Difference'].values.quantity.magnitude,
            height=height,
            distance=distance
        )

        return peak_indices

    def _add_custom_to_ax(self, time_unit: str | Quantity, signal_unit: str | Quantity, ax: plt.Axes) -> None:
        """
        Add algorithm-specific elements to the plot.

        Parameters
        ----------
        time_unit : str or Quantity
            The unit for the time axis (e.g., 'microsecond').
        signal_unit : str or Quantity
            The unit for the signal axis (e.g., 'volt').
        ax : matplotlib.axes.Axes
            The Axes object to add elements to.
        """
        ax.plot(
            self.data.Time,
            self.data.Difference,
            linestyle='--',
            color='C1',
            label='MA-difference'
        )

        if self.threshold is None:
            # The default threshold is None: there is no threshold line to draw.
            return

        # Plot the signal threshold line
        ax.axhline(y=self.threshold.to(signal_unit).magnitude, color='black', linestyle='--', label='Threshold', lw=1)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from FlowCyPy.units import meter, joule, second, farad, kelvin, coulomb
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Shared configuration options. The keys (arbitrary_types_allowed, extra='forbid')
# look like dataclass/validation config for use elsewhere in the package —
# NOTE(review): confirm at the use sites.
config_dict = {
    'arbitrary_types_allowed': True,
    'kw_only': True,
    'slots': True,
    'extra': 'forbid',
}
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PhysicalConstant:
    """
    Namespace of physical constants expressed as unit-aware quantities.

    Values follow the 2019 SI definitions (exact for h, c, kb, e) and CODATA 2018
    for the vacuum permittivity.
    """
    h = 6.62607015e-34 * joule * second  # Planck constant (exact)
    c = 299_792_458.0 * meter / second  # Speed of light in vacuum (exact)
    epsilon_0 = 8.8541878128e-12 * farad / meter  # Vacuum permittivity (CODATA 2018)
    pi = np.pi  # Pi, what else?
    kb = 1.380649e-23 * joule / kelvin  # Boltzmann constant (exact)
    e = 1.602176634e-19 * coulomb  # Elementary (electron) charge (exact)
|
FlowCyPy/plottings.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
from typing import Optional, Union, Tuple
|
|
2
|
+
from MPSPlots.styles import mps
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import seaborn as sns
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MetricPlotter:
    """
    A class for creating 2D density and scatter plots of a feature measured by two detectors.

    Parameters
    ----------
    coincidence_dataframe : pd.DataFrame
        The dataframe containing the coincidence data. It must expose one pint-typed
        column per ``(detector_name, feature)`` pair and a ``Label`` column, which is
        used as the scatter-plot hue.
    detector_names : tuple of str
        A tuple containing the names of the two detectors (detector 0 and detector 1).
    """

    def __init__(self, coincidence_dataframe: pd.DataFrame, detector_names: Tuple[str, str]):
        # Index levels become regular columns so they are reachable by name.
        self.coincidence_dataframe = coincidence_dataframe.reset_index()
        self.detector_names = detector_names

    def _extract_feature_data(self, feature: str):
        """
        Extracts and processes the feature data for the two detectors.

        Each series is converted to a compact unit derived from its maximum value.

        Parameters
        ----------
        feature : str
            The feature to extract.

        Returns
        -------
        tuple
            ``(x_data, y_data, x_units, y_units)``: the two pint-typed series and
            the pint units they were converted to.
        """
        name_0, name_1 = self.detector_names
        x_data = self.coincidence_dataframe[(name_0, feature)]
        y_data = self.coincidence_dataframe[(name_1, feature)]

        x_units = x_data.max().to_compact().units
        y_units = y_data.max().to_compact().units

        x_data = x_data.pint.to(x_units)
        y_data = y_data.pint.to(y_units)

        return x_data, y_data, x_units, y_units

    def _create_density_plot(
            self,
            x_data: pd.Series,
            y_data: pd.Series,
            bandwidth_adjust: float,
    ):
        """
        Creates a KDE density plot.

        Parameters
        ----------
        x_data : pd.Series
            The x-axis data.
        y_data : pd.Series
            The y-axis data.
        bandwidth_adjust : float
            Adjustment factor for the KDE bandwidth.

        Returns
        -------
        sns.JointGrid
            The seaborn JointGrid object.
        """
        return sns.jointplot(
            data=self.coincidence_dataframe,
            x=x_data,
            y=y_data,
            kind="kde",
            alpha=0.8,
            fill=True,
            joint_kws={"alpha": 0.7, "bw_adjust": bandwidth_adjust},
        )

    def _add_scatterplot(
            self,
            g: sns.JointGrid,
            x_data: pd.Series,
            y_data: pd.Series,
            color_palette: Optional[Union[str, dict]],
    ):
        """
        Adds a scatterplot layer, colored by ``Label``, to the KDE density plot.

        Parameters
        ----------
        g : sns.JointGrid
            The seaborn JointGrid object to which the scatterplot is added.
        x_data : pd.Series
            The x-axis data.
        y_data : pd.Series
            The y-axis data.
        color_palette : str or dict, optional
            The color palette to use for the hue in the scatterplot.

        Returns
        -------
        None
        """
        sns.scatterplot(
            data=self.coincidence_dataframe,
            x=x_data,
            y=y_data,
            hue="Label",
            palette=color_palette,
            ax=g.ax_joint,
            alpha=0.6,
            zorder=1,
        )

    def _apply_axis_labels(
            self,
            g: sns.JointGrid,
            feature: str,
            x_units,
            y_units,
    ):
        """
        Sets the x and y labels with units on the plot.

        Parameters
        ----------
        g : sns.JointGrid
            The seaborn JointGrid object.
        feature : str
            The feature being plotted.
        x_units : pint.Unit
            Units of the x-axis data.
        y_units : pint.Unit
            Units of the y-axis data.

        Returns
        -------
        None
        """
        name_0, name_1 = self.detector_names
        # Same "feature : detector [unit]" format for both axes (and the log-scale branch).
        g.ax_joint.set_xlabel(f"{feature} : {name_0} [{x_units:P}]")
        g.ax_joint.set_ylabel(f"{feature} : {name_1} [{y_units:P}]")

    def _apply_axis_limits(
            self,
            g: sns.JointGrid,
            x_limits: Optional[Tuple],
            y_limits: Optional[Tuple],
            x_units,
            y_units):
        """
        Sets the axis limits if specified.

        Parameters
        ----------
        g : sns.JointGrid
            The seaborn JointGrid object.
        x_limits : tuple of Quantity, optional
            The x-axis limits (min, max), by default None.
        y_limits : tuple of Quantity, optional
            The y-axis limits (min, max), by default None.
        x_units : pint.Unit
            Units of the x-axis data; limits are converted to these before use.
        y_units : pint.Unit
            Units of the y-axis data; limits are converted to these before use.

        Returns
        -------
        None
        """
        if x_limits:
            x0, x1 = x_limits
            g.ax_joint.set_xlim(x0.to(x_units).magnitude, x1.to(x_units).magnitude)

        if y_limits:
            y0, y1 = y_limits
            g.ax_joint.set_ylim(y0.to(y_units).magnitude, y1.to(y_units).magnitude)

    @staticmethod
    def _resolve_equal_limits(x_data, y_data, x_limits, y_limits):
        """
        Return identical ``(min, max)`` limits, as pint Quantities, covering both axes.

        User-supplied limits take precedence over the data range. Everything stays a
        Quantity so that comparisons respect units (x and y may be expressed in
        different compact units) and the result can be converted with ``.to()``.
        """
        min_x, max_x = x_limits if x_limits else (x_data.min(), x_data.max())
        min_y, max_y = y_limits if y_limits else (y_data.min(), y_data.max())

        common = (min(min_x, min_y), max(max_x, max_y))
        return common, common

    def plot(
            self,
            feature: str,
            show: bool = True,
            log_plot: bool = True,
            x_limits: Optional[Tuple] = None,
            y_limits: Optional[Tuple] = None,
            equal_axes: bool = False,
            bandwidth_adjust: float = 1.0,
            color_palette: Optional[Union[str, dict]] = 'tab10') -> None:
        """
        Generates a 2D density plot of the scattering intensities, overlaid with individual peak heights.

        Parameters
        ----------
        feature : str
            The feature to plot (e.g., 'intensity').
        show : bool, optional
            Whether to display the plot immediately, by default True.
        log_plot : bool, optional
            Whether to use logarithmic scaling for the plot axes, by default True.
            Log-scaled plots show the scatter layer only (no KDE).
        x_limits : tuple of Quantity, optional
            The x-axis limits (min, max), by default None.
        y_limits : tuple of Quantity, optional
            The y-axis limits (min, max), by default None.
        equal_axes : bool, optional
            Whether to enforce the same range for the x and y axes, by default False.
        bandwidth_adjust : float, optional
            Bandwidth adjustment factor for the kernel density estimate of the marginal distributions. Default is 1.0.
        color_palette : str or dict, optional
            The color palette to use for the hue in the scatterplot.

        Returns
        -------
        None
        """
        x_data, y_data, x_units, y_units = self._extract_feature_data(feature)

        # Determine equal axis limits if required
        if equal_axes:
            x_limits, y_limits = self._resolve_equal_limits(x_data, y_data, x_limits, y_limits)

        with plt.style.context(mps):
            if not log_plot:
                # KDE + Scatterplot for linear plots
                g = self._create_density_plot(x_data, y_data, bandwidth_adjust)
                self._add_scatterplot(g, x_data, y_data, color_palette)
                self._apply_axis_labels(g, feature, x_units, y_units)
                self._apply_axis_limits(g, x_limits, y_limits, x_units, y_units)
            else:
                # Scatterplot only for log-scaled plots
                _, ax = plt.subplots()
                sns.scatterplot(
                    x=x_data,
                    y=y_data,
                    hue=self.coincidence_dataframe["Label"],
                    palette=color_palette,
                    alpha=0.6,
                    ax=ax,
                )
                ax.set_xscale("log")
                ax.set_yscale("log")
                ax.set_xlabel(f"{feature} : {self.detector_names[0]} [{x_units:P}]")
                ax.set_ylabel(f"{feature} : {self.detector_names[1]} [{y_units:P}]")
                if x_limits:
                    ax.set_xlim([lim.to(x_units).magnitude for lim in x_limits])
                if y_limits:
                    ax.set_ylim([lim.to(y_units).magnitude for lim in y_limits])
                ax.legend()

            plt.tight_layout()
            if show:
                plt.show()
|
|
270
|
+
|